Skip to content

Commit 589ec31

Browse files
#775 get tom's test working (hacky)
1 parent 1f0cbe2 commit 589ec31

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

pybamm/solvers/casadi_solver.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -242,14 +242,23 @@ def integrate_casadi(self, rhs, algebraic, y0, t_eval, inputs=None):
242242
# set up and solve
243243
t = casadi.MX.sym("t")
244244
u = casadi.vertcat(*[x for x in inputs.values()])
245-
y_diff = casadi.MX.sym("y_diff", rhs(0, y0, u).shape[0])
245+
y_diff = self.y_diff
246246
problem = {"t": t, "x": y_diff}
247247
if algebraic is None:
248-
problem.update({"ode": rhs(t, y_diff, u)})
248+
y_casadi_w_ext = casadi.vertcat(y_diff, self.y_ext[y_diff.shape[0] :])
249+
problem.update({"ode": rhs(t, y_casadi_w_ext, u)})
249250
else:
250-
y_alg = casadi.MX.sym("y_alg", algebraic(0, y0, u).shape[0])
251-
y = casadi.vertcat(y_diff, y_alg)
252-
problem.update({"z": y_alg, "ode": rhs(t, y, u), "alg": algebraic(t, y, u)})
251+
y_alg = self.y_alg
252+
y_casadi_w_ext = casadi.vertcat(
253+
y_diff, y_alg, self.y_ext[y_diff.shape[0] + y_alg.shape[0] :]
254+
)
255+
problem.update(
256+
{
257+
"z": y_alg,
258+
"ode": rhs(t, y_casadi_w_ext, u),
259+
"alg": algebraic(t, y_casadi_w_ext, u),
260+
}
261+
)
253262
integrator = casadi.integrator("F", self.method, problem, options)
254263
try:
255264
# Try solving

pybamm/solvers/dae_solver.py

+2
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,8 @@ def get_event_class(event):
340340
self.events = model.events
341341
self.event_funs = [get_event_class(event) for event in casadi_events.values()]
342342
self.jacobian = jacobian
343+
self.y_diff = y_diff
344+
self.y_alg = y_alg
343345

344346
# Save CasADi functions for the CasADi solver
345347
# Note: when we pass to casadi the ode part of the problem must be in explicit

pybamm/solvers/ode_solver.py

+2
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ def get_event_class(event):
232232

233233
# Add the solver attributes
234234
self.y0 = y0
235+
self.y_diff = y_casadi
236+
self.y_casadi_w_ext = y_casadi_w_ext
235237
self.dydt = DydtCasadi(model, concatenated_rhs_fn)
236238
self.events = model.events
237239
self.event_funs = [get_event_class(event) for event in casadi_events.values()]

0 commit comments

Comments
 (0)