@@ -963,38 +963,45 @@ def solve(
963963 return solutions
964964
965965 @staticmethod
966- def _get_discontinuity_start_end_indices (model , inputs , t_eval ):
967- if not model .discontinuity_events_eval :
968- pybamm .logger .verbose ("No discontinuity events found" )
969- return [0 ], [len (t_eval )], t_eval
970-
971- # Calculate discontinuities
972- discontinuities = [
973- # Assuming that discontinuities do not depend on
974- # input parameters when len(input_list) > 1, only
975- # `inputs` is passed to `evaluate`.
976- # See https://github.com/pybamm-team/PyBaMM/pull/1261
977- event .expression .evaluate (inputs = inputs )
978- for event in model .discontinuity_events_eval
979- ]
980-
966+ def _sort_and_clean_discontinuities (discontinuities , t_eval ):
981967 # make sure they are increasing in time
982968 discontinuities = sorted (discontinuities )
983969
984- # remove any identical discontinuities
970+ # remove any identical discontinuities, and also unwanted discontinuities
971+ # at the beginning of the integration (see https://github.com/pybamm-team/PyBaMM/pull/5075)
985972 discontinuities = [
986973 v
987974 for i , v in enumerate (discontinuities )
988975 if (
989976 i == len (discontinuities ) - 1
990977 or discontinuities [i ] < discontinuities [i + 1 ]
991978 )
992- and v > t_eval [0 ]
979+ and v > t_eval [0 ]
993980 ]
994981
995982 # remove any discontinuities after end of t_eval
996983 discontinuities = [v for v in discontinuities if v < t_eval [- 1 ]]
997984
985+ return discontinuities
986+
987+ def _get_discontinuity_start_end_indices (self , model , inputs , t_eval ):
988+ if not model .discontinuity_events_eval :
989+ pybamm .logger .verbose ("No discontinuity events found" )
990+ return [0 ], [len (t_eval )], t_eval
991+
992+ # Calculate discontinuities
993+ discontinuities = [
994+ # Assuming that discontinuities do not depend on
995+ # input parameters when len(input_list) > 1, only
996+ # `inputs` is passed to `evaluate`.
997+ # See https://github.com/pybamm-team/PyBaMM/pull/1261
998+ event .expression .evaluate (inputs = inputs )
999+ for event in model .discontinuity_events_eval
1000+ ]
1001+
1002+ # sort and remove unwanted discontinuities
1003+ discontinuities = self ._sort_and_clean_discontinuities (discontinuities , t_eval )
1004+
9981005 pybamm .logger .verbose (f"Discontinuity events found at t = { discontinuities } " )
9991006 if isinstance (inputs , list ):
10001007 raise pybamm .SolverError (
0 commit comments