Skip to content

Commit 943f418

Browse files
Merge pull request #1321 from brosaplanella/issue-1320-casadi-error-extrap-events
Issue 1320 casadi error extrap events
2 parents 562e14a + b4a584e commit 943f418

File tree

3 files changed

+96
-65
lines changed

3 files changed

+96
-65
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
## Bug fixes
2121

22+
- Fixed a bug in `CasadiSolver` safe mode which crashed when there were extrapolation events but no termination events ([#1321](https://github.com/pybamm-team/PyBaMM/pull/1321))
2223
- When an `Interpolant` is extrapolated an error is raised for `CasadiSolver` (and a warning is raised for the other solvers) ([#1315](https://github.com/pybamm-team/PyBaMM/pull/1315))
2324
- Fixed `Simulation` and `model.new_copy` to fix a bug where changes to the model were overwritten ([#1278](https://github.com/pybamm-team/PyBaMM/pull/1278))
2425

pybamm/solvers/casadi_solver.py

+74-65
Original file line numberDiff line numberDiff line change
@@ -142,37 +142,42 @@ def _integrate(self, model, t_eval, inputs=None):
142142
# Step-and-check
143143
t = t_eval[0]
144144
t_f = t_eval[-1]
145-
init_event_signs = np.sign(
146-
np.concatenate(
147-
[event(t, y0, inputs) for event in model.terminate_events_eval]
145+
if model.terminate_events_eval:
146+
init_event_signs = np.sign(
147+
np.concatenate(
148+
[event(t, y0, inputs) for event in model.terminate_events_eval]
149+
)
148150
)
149-
)
151+
else:
152+
init_event_signs = np.sign([])
150153

151-
extrap_event = [
152-
event(t, y0, inputs)
153-
for event in model.interpolant_extrapolation_events_eval
154-
]
155-
156-
if extrap_event:
157-
if (np.concatenate(extrap_event) < self.extrap_tol).any():
158-
extrap_event_names = []
159-
for event in model.events:
160-
if (
161-
event.event_type
162-
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
163-
and (
164-
event.expression.evaluate(t, y0.full(), inputs=inputs,)
165-
< self.extrap_tol
166-
).any()
167-
):
168-
extrap_event_names.append(event.name[12:])
154+
if model.interpolant_extrapolation_events_eval:
155+
extrap_event = [
156+
event(t, y0, inputs)
157+
for event in model.interpolant_extrapolation_events_eval
158+
]
159+
if extrap_event:
160+
if (np.concatenate(extrap_event) < self.extrap_tol).any():
161+
extrap_event_names = []
162+
for event in model.events:
163+
if (
164+
event.event_type
165+
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
166+
and (
167+
event.expression.evaluate(
168+
t, y0.full(), inputs=inputs,
169+
)
170+
< self.extrap_tol
171+
).any()
172+
):
173+
extrap_event_names.append(event.name[12:])
169174

170-
raise pybamm.SolverError(
171-
"CasADI solver failed because the following interpolation "
172-
"bounds were exceeded at the initial conditions: {}. "
173-
"You may need to provide additional interpolation points "
174-
"outside these bounds.".format(extrap_event_names)
175-
)
175+
raise pybamm.SolverError(
176+
"CasADI solver failed because the following interpolation "
177+
"bounds were exceeded at the initial conditions: {}. "
178+
"You may need to provide additional interpolation points "
179+
"outside these bounds.".format(extrap_event_names)
180+
)
176181

177182
pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))
178183

@@ -240,44 +245,48 @@ def _integrate(self, model, t_eval, inputs=None):
240245
)
241246
)
242247
# Check most recent y to see if any events have been crossed
243-
new_event_signs = np.sign(
244-
np.concatenate(
245-
[
246-
event(t, current_step_sol.y[:, -1], inputs)
247-
for event in model.terminate_events_eval
248-
]
249-
)
250-
)
251-
252-
extrap_event = [
253-
event(t, current_step_sol.y[:, -1], inputs=inputs)
254-
for event in model.interpolant_extrapolation_events_eval
255-
]
256-
257-
if extrap_event:
258-
if (np.concatenate(extrap_event) < self.extrap_tol).any():
259-
extrap_event_names = []
260-
for event in model.events:
261-
if (
262-
event.event_type
263-
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
264-
and (
265-
event.expression.evaluate(
266-
t,
267-
current_step_sol.y[:, -1].full(),
268-
inputs=inputs,
269-
)
270-
< self.extrap_tol
271-
).any()
272-
):
273-
extrap_event_names.append(event.name[12:])
274-
275-
raise pybamm.SolverError(
276-
"CasADI solver failed because the following interpolation "
277-
"bounds were exceeded: {}. You may need to provide "
278-
"additional interpolation points outside these "
279-
"bounds.".format(extrap_event_names)
248+
if model.terminate_events_eval:
249+
new_event_signs = np.sign(
250+
np.concatenate(
251+
[
252+
event(t, current_step_sol.y[:, -1], inputs)
253+
for event in model.terminate_events_eval
254+
]
280255
)
256+
)
257+
else:
258+
new_event_signs = np.sign([])
259+
260+
if model.interpolant_extrapolation_events_eval:
261+
extrap_event = [
262+
event(t, current_step_sol.y[:, -1], inputs=inputs)
263+
for event in model.interpolant_extrapolation_events_eval
264+
]
265+
266+
if extrap_event:
267+
if (np.concatenate(extrap_event) < self.extrap_tol).any():
268+
extrap_event_names = []
269+
for event in model.events:
270+
if (
271+
event.event_type
272+
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
273+
and (
274+
event.expression.evaluate(
275+
t,
276+
current_step_sol.y[:, -1].full(),
277+
inputs=inputs,
278+
)
279+
< self.extrap_tol
280+
).any()
281+
):
282+
extrap_event_names.append(event.name[12:])
283+
284+
raise pybamm.SolverError(
285+
"CasADI solver failed because the following "
286+
"interpolation bounds were exceeded: {}. You may need "
287+
"to provide additional interpolation points outside "
288+
"these bounds.".format(extrap_event_names)
289+
)
281290

282291
# Exit loop if the sign of an event changes
283292
# Locate the event time using a root finding algorithm and

tests/unit/test_solvers/test_casadi_solver.py

+21
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,27 @@ def test_interpolant_extrapolate(self):
467467
with self.assertRaisesRegex(pybamm.SolverError, "interpolation bounds"):
468468
sim.solve()
469469

470+
def test_casadi_safe_no_termination(self):
471+
model = pybamm.BaseModel()
472+
v = pybamm.Variable("v")
473+
model.rhs = {v: -1}
474+
model.initial_conditions = {v: 1}
475+
model.events.append(
476+
pybamm.Event(
477+
"Triggered event", v - 0.5, pybamm.EventType.INTERPOLANT_EXTRAPOLATION,
478+
)
479+
)
480+
model.events.append(
481+
pybamm.Event(
482+
"Ignored event", v + 10, pybamm.EventType.INTERPOLANT_EXTRAPOLATION,
483+
)
484+
)
485+
solver = pybamm.CasadiSolver(mode="safe")
486+
solver.set_up(model)
487+
488+
with self.assertRaisesRegex(pybamm.SolverError, "interpolation bounds"):
489+
solver.solve(model, t_eval=[0, 1])
490+
470491

471492
class TestCasadiSolverSensitivity(unittest.TestCase):
472493
def test_solve_with_symbolic_input(self):

0 commit comments

Comments
 (0)