Skip to content

Commit 33aff3a

Browse files
committed
#853 fix tests
1 parent 9973692 commit 33aff3a

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

pybamm/solvers/casadi_solver.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,11 @@ def _integrate(self, model, t_eval, inputs=None):
135135
t_f = t_eval[-1]
136136
init_event_signs = np.sign(
137137
np.concatenate(
138-
[
139-
event(t, model.y0, inputs)
140-
for event in model.terminate_events_eval
141-
]
138+
[event(t, y0, inputs) for event in model.terminate_events_eval]
142139
)
143140
)
144141
pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))
145-
y0 = model.y0
142+
146143
# Initialize solution
147144
solution = pybamm.Solution(np.array([t]), y0[:, np.newaxis])
148145
solution.solve_time = 0
@@ -273,6 +270,9 @@ def event_fun(t):
273270
y0 = solution.y[:, -1]
274271
return solution
275272
elif self.mode == "old safe":
273+
y0 = model.y0
274+
if isinstance(y0, casadi.DM):
275+
y0 = y0.full().flatten()
276276
# Step-and-check
277277
t = t_eval[0]
278278
init_event_signs = np.sign(

tests/unit/test_solvers/test_casadi_solver.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ def test_model_solver_failure(self):
8181
disc = pybamm.Discretisation()
8282
disc.process_model(model)
8383

84-
solver = pybamm.CasadiSolver(regularity_check=False)
85-
solver_old = pybamm.CasadiSolver(mode="old safe", regularity_check=False)
84+
solver = pybamm.CasadiSolver(extra_options_call={"regularity_check": False})
85+
solver_old = pybamm.CasadiSolver(
86+
mode="old safe", extra_options_call={"regularity_check": False}
87+
)
8688
# Solve with failure at t=2
8789
t_eval = np.linspace(0, 20, 100)
8890
with self.assertRaises(pybamm.SolverError):

0 commit comments

Comments
 (0)