Skip to content

Commit c79f0ca

Browse files
#1100 working on casadi solver sensitivities
1 parent b887515 commit c79f0ca

File tree

4 files changed

+402
-377
lines changed

4 files changed

+402
-377
lines changed

pybamm/solvers/base_solver.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -510,20 +510,27 @@ def _set_initial_conditions(self, model, inputs, update_rhs):
510510
Whether to update the rhs. True for 'solve', False for 'step'.
511511
512512
"""
513+
# Make inputs symbolic if calculating sensitivities with casadi
514+
if self.sensitivity == "casadi":
515+
symbolic_inputs = casadi.MX.sym(
516+
"inputs", casadi.vertcat(*inputs.values()).shape[0]
517+
)
518+
else:
519+
symbolic_inputs = inputs
513520
if self.algebraic_solver is True:
514521
# Don't update model.y0
515522
return None
516523
elif len(model.algebraic) == 0:
517524
if update_rhs is True:
518525
# Recalculate initial conditions for the rhs equations
519-
model.y0 = model.init_eval(inputs)
526+
y0 = model.init_eval(symbolic_inputs)
520527
else:
521528
# Don't update model.y0
522529
return None
523530
else:
524531
if update_rhs is True:
525532
# Recalculate initial conditions for the rhs equations
526-
y0_from_inputs = model.init_eval(inputs)
533+
y0_from_inputs = model.init_eval(symbolic_inputs)
527534
# Reuse old solution for algebraic equations
528535
y0_from_model = model.y0
529536
len_rhs = model.len_rhs
@@ -534,7 +541,12 @@ def _set_initial_conditions(self, model, inputs, update_rhs):
534541
model.y0 = casadi.vertcat(
535542
y0_from_inputs[:len_rhs], y0_from_model[len_rhs:]
536543
)
537-
model.y0 = self.calculate_consistent_state(model, 0, inputs)
544+
y0 = self.calculate_consistent_state(model, 0, inputs)
545+
# Make y0 a function of inputs if doing symbolic with casadi
546+
if self.sensitivity == "casadi":
547+
model.y0 = casadi.Function("y0", [symbolic_inputs], [y0])
548+
else:
549+
model.y0 = y0
538550

539551
def calculate_consistent_state(self, model, time=0, inputs=None):
540552
"""

pybamm/solvers/casadi_algebraic_solver.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,11 @@ def _integrate(self, model, t_eval, inputs=None):
107107
else:
108108
# Set up
109109
t_sym = casadi.MX.sym("t")
110+
y0_diff_sym = casadi.MX.sym("y0_diff", y0_diff.shape[0])
110111
y_alg_sym = casadi.MX.sym("y_alg", y0_alg.shape[0])
111-
y_sym = casadi.vertcat(y0_diff, y_alg_sym)
112+
y_sym = casadi.vertcat(y0_diff_sym, y_alg_sym)
112113

113-
t_and_inputs_sym = casadi.vertcat(t_sym, symbolic_inputs)
114+
t_y0diff_inputs_sym = casadi.vertcat(t_sym, y0_diff_sym, symbolic_inputs)
114115
alg = model.casadi_algebraic(t_sym, y_sym, symbolic_inputs)
115116

116117
# Set constraints vector in the casadi format
@@ -126,7 +127,7 @@ def _integrate(self, model, t_eval, inputs=None):
126127
roots = casadi.rootfinder(
127128
"roots",
128129
"newton",
129-
dict(x=y_alg_sym, p=t_and_inputs_sym, g=alg),
130+
dict(x=y_alg_sym, p=t_y0diff_inputs_sym, g=alg),
130131
{
131132
**self.extra_options,
132133
"abstol": self.tol,
@@ -155,12 +156,12 @@ def _integrate(self, model, t_eval, inputs=None):
155156
# If doing sensitivity with casadi, evaluate with symbolic inputs
156157
# Otherwise, evaluate with actual inputs
157158
if self.sensitivity == "casadi":
158-
t_eval_and_inputs = casadi.vertcat(t, symbolic_inputs)
159+
t_y0_diff_inputs = casadi.vertcat(t, y0_diff, symbolic_inputs)
159160
else:
160-
t_eval_and_inputs = casadi.vertcat(t, inputs)
161+
t_y0_diff_inputs = casadi.vertcat(t, y0_diff, inputs)
161162
# Solve
162163
try:
163-
y_alg_sol = roots(y0_alg, t_eval_and_inputs)
164+
y_alg_sol = roots(y0_alg, t_y0_diff_inputs)
164165
success = True
165166
message = None
166167
# Check final output

pybamm/solvers/casadi_solver.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def _integrate(self, model, t_eval, inputs=None):
208208

209209
if self.mode == "safe":
210210
# update integrator with the grid
211-
self.create_integrator(model, inputs, t_window)
211+
self.create_integrator(model, inputs_dict, t_window)
212212
# Try to solve with the current global step, if it fails then
213213
# halve the step size and try again.
214214
try:
@@ -347,7 +347,6 @@ def create_integrator(self, model, inputs_dict, t_eval=None):
347347
self.integrators[model] = (integrator, use_grid)
348348
return integrator
349349
else:
350-
y0 = model.y0
351350
rhs = model.casadi_rhs
352351
algebraic = model.casadi_algebraic
353352

@@ -370,6 +369,12 @@ def create_integrator(self, model, inputs_dict, t_eval=None):
370369
# set up and solve
371370
t = casadi.MX.sym("t")
372371
p = casadi.MX.sym("p", inputs.shape[0])
372+
# If the initial conditions depend on inputs, evaluate the function
373+
if isinstance(model.y0, casadi.Function):
374+
y0 = model.y0(p)
375+
else:
376+
y0 = model.y0
377+
373378
y_diff = casadi.MX.sym("y_diff", rhs(0, y0, p).shape[0])
374379

375380
if use_grid is False:
@@ -420,6 +425,13 @@ def _run_integrator(self, model, y0, inputs_dict, t_eval):
420425
else:
421426
inputs_eval = inputs
422427
integrator, use_grid = self.integrators[model]
428+
429+
# If the initial conditions depend on inputs, evaluate the function
430+
if isinstance(y0, casadi.Function):
431+
y0 = y0(symbolic_inputs)
432+
else:
433+
y0 = y0
434+
423435
# Split up initial conditions into differential and algebraic
424436
# Check y0 to see if it includes sensitivities
425437
if model.len_rhs_and_alg == y0.shape[0]:

0 commit comments

Comments
 (0)