Skip to content

Commit b887515

Browse files
#1100 working on casadi solver
1 parent 26447b3 commit b887515

File tree

4 files changed

+876
-999
lines changed

4 files changed

+876
-999
lines changed

pybamm/solvers/casadi_algebraic_solver.py

+59-45
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def __init__(self, tol=1e-6, extra_options=None, sensitivity=None):
4040
self.extra_options = extra_options or {}
4141
pybamm.citations.register("Andersson2019")
4242

43+
self.rootfinders = {}
44+
self.y_sols = {}
45+
4346
@property
4447
def tol(self):
4548
return self._tol
@@ -62,24 +65,12 @@ def _integrate(self, model, t_eval, inputs=None):
6265
Any input parameters to pass to the model when solving.
6366
"""
6467
# Record whether there are any symbolic inputs
65-
inputs = inputs or {}
68+
inputs_dict = inputs or {}
6669
# Create casadi objects for the root-finder
67-
inputs_dict = inputs
6870
inputs = casadi.vertcat(*[v for v in inputs.values()])
6971

70-
if self.sensitivity == "casadi" and inputs_dict != {}:
71-
# Create symbolic inputs for sensitivity analysis
72-
symbolic_inputs_list = []
73-
for name, value in inputs_dict.items():
74-
if isinstance(value, numbers.Number):
75-
symbolic_inputs_list.append(casadi.MX.sym(name))
76-
else:
77-
symbolic_inputs_list.append(casadi.MX.sym(name, value.shape[0]))
78-
symbolic_inputs = casadi.vertcat(*[p for p in symbolic_inputs_list])
79-
inputs_for_alg = symbolic_inputs
80-
else:
81-
symbolic_inputs = casadi.DM()
82-
inputs_for_alg = inputs
72+
# Create symbolic inputs
73+
symbolic_inputs = casadi.MX.sym("inputs", inputs.shape[0])
8374

8475
y0 = model.y0
8576
# The casadi algebraic solver can read rhs equations, but leaves them unchanged
@@ -101,34 +92,50 @@ def _integrate(self, model, t_eval, inputs=None):
10192

10293
y_alg = None
10394

104-
# Set up
105-
t_sym = casadi.MX.sym("t")
106-
y_alg_sym = casadi.MX.sym("y_alg", y0_alg.shape[0])
107-
y_sym = casadi.vertcat(y0_diff, y_alg_sym)
108-
109-
t_and_inputs_sym = casadi.vertcat(t_sym, symbolic_inputs)
110-
alg = model.casadi_algebraic(t_sym, y_sym, inputs_for_alg)
111-
112-
# Set constraints vector in the casadi format
113-
# Constrain the unknowns. 0 (default): no constraint on ui, 1: ui >= 0.0,
114-
# -1: ui <= 0.0, 2: ui > 0.0, -2: ui < 0.0.
115-
constraints = np.zeros_like(model.bounds[0], dtype=int)
116-
# If the lower bound is positive then the variable must always be positive
117-
constraints[model.bounds[0] >= 0] = 1
118-
# If the upper bound is negative then the variable must always be negative
119-
constraints[model.bounds[1] <= 0] = -1
120-
121-
# Set up rootfinder
122-
roots = casadi.rootfinder(
123-
"roots",
124-
"newton",
125-
dict(x=y_alg_sym, p=t_and_inputs_sym, g=alg),
126-
{
127-
**self.extra_options,
128-
"abstol": self.tol,
129-
"constraints": list(constraints[len_rhs:]),
130-
},
131-
)
95+
if model in self.rootfinders:
96+
if self.sensitivity == "casadi":
97+
# Reuse (symbolic) solution with new inputs
98+
y_sol = self.y_sols[model]
99+
return pybamm.Solution(
100+
t_eval,
101+
y_sol,
102+
termination="success",
103+
model=model,
104+
inputs=inputs_dict,
105+
)
106+
roots = self.rootfinders[model]
107+
else:
108+
# Set up
109+
t_sym = casadi.MX.sym("t")
110+
y_alg_sym = casadi.MX.sym("y_alg", y0_alg.shape[0])
111+
y_sym = casadi.vertcat(y0_diff, y_alg_sym)
112+
113+
t_and_inputs_sym = casadi.vertcat(t_sym, symbolic_inputs)
114+
alg = model.casadi_algebraic(t_sym, y_sym, symbolic_inputs)
115+
116+
# Set constraints vector in the casadi format
117+
# Constrain the unknowns. 0 (default): no constraint on ui, 1: ui >= 0.0,
118+
# -1: ui <= 0.0, 2: ui > 0.0, -2: ui < 0.0.
119+
constraints = np.zeros_like(model.bounds[0], dtype=int)
120+
# If the lower bound is positive then the variable must always be positive
121+
constraints[model.bounds[0] >= 0] = 1
122+
# If the upper bound is negative then the variable must always be negative
123+
constraints[model.bounds[1] <= 0] = -1
124+
125+
# Set up rootfinder
126+
roots = casadi.rootfinder(
127+
"roots",
128+
"newton",
129+
dict(x=y_alg_sym, p=t_and_inputs_sym, g=alg),
130+
{
131+
**self.extra_options,
132+
"abstol": self.tol,
133+
"constraints": list(constraints[len_rhs:]),
134+
},
135+
)
136+
137+
self.rootfinders[model] = roots
138+
132139
for idx, t in enumerate(t_eval):
133140
# Evaluate algebraic with new t and previous y0, if it's already close
134141
# enough then keep it
@@ -145,10 +152,15 @@ def _integrate(self, model, t_eval, inputs=None):
145152
y_alg = casadi.horzcat(y_alg, y0_alg)
146153
# Otherwise calculate new y_sol
147154
else:
148-
t_eval_inputs_sym = casadi.vertcat(t, symbolic_inputs)
155+
# If doing sensitivity with casadi, evaluate with symbolic inputs
156+
# Otherwise, evaluate with actual inputs
157+
if self.sensitivity == "casadi":
158+
t_eval_and_inputs = casadi.vertcat(t, symbolic_inputs)
159+
else:
160+
t_eval_and_inputs = casadi.vertcat(t, inputs)
149161
# Solve
150162
try:
151-
y_alg_sol = roots(y0_alg, t_eval_inputs_sym)
163+
y_alg_sol = roots(y0_alg, t_eval_and_inputs)
152164
success = True
153165
message = None
154166
# Check final output
@@ -193,6 +205,8 @@ def _integrate(self, model, t_eval, inputs=None):
193205
# If doing sensitivity, return the solution as a function of the inputs
194206
if self.sensitivity == "casadi":
195207
y_sol = casadi.Function("y_sol", [symbolic_inputs], [y_sol])
208+
# Save the solution, can just reuse and change the inputs
209+
self.y_sols[model] = y_sol
196210
# Return solution object (no events, so pass None to t_event, y_event)
197211
return pybamm.Solution(
198212
t_eval, y_sol, termination="success", model=model, inputs=inputs_dict

pybamm/solvers/casadi_solver.py

+40-21
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,12 @@ class CasadiSolver(pybamm.BaseSolver):
5959
Please consult `CasADi documentation <https://tinyurl.com/y5rk76os>`_ for
6060
details.
6161
sensitivity : bool, optional
62-
Whether to explicitly formulate and solve the forward sensitivity equations.
63-
See :class:`pybamm.BaseSolver`
62+
Whether (and how) to calculate sensitivities when solving. Options are:
6463
64+
- None: no sensitivities
65+
- "explicit forward": explicitly formulate the sensitivity equations.
66+
See :class:`pybamm.BaseSolver`
67+
- "casadi": use casadi to differentiate through the integrator
6568
"""
6669

6770
def __init__(
@@ -104,6 +107,7 @@ def __init__(
104107
# Initialize
105108
self.integrators = {}
106109
self.integrator_specs = {}
110+
self.y_sols = {}
107111

108112
pybamm.citations.register("Andersson2019")
109113

@@ -122,24 +126,29 @@ def _integrate(self, model, t_eval, inputs=None):
122126
"""
123127
# Record whether there are any symbolic inputs
124128
inputs_dict = inputs or {}
125-
has_symbolic_inputs = any(
126-
isinstance(v, casadi.MX) for v in inputs_dict.values()
127-
)
128129

129130
# convert inputs to casadi format
130131
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
131132

132-
if has_symbolic_inputs:
133-
# Create integrator without grid to avoid having to create several times
134-
self.create_integrator(model, inputs)
135-
solution = self._run_integrator(model, model.y0, inputs_dict, t_eval)
133+
if self.sensitivity == "casadi" and inputs_dict != {}:
134+
# If the solution has already been created, we can reuse it
135+
if model in self.y_sols:
136+
y_sol = self.y_sols[model]
137+
solution = pybamm.Solution(
138+
t_eval, y_sol, model=model, inputs=inputs_dict
139+
)
140+
else:
141+
# Create integrator without grid, which will be called repeatedly
142+
# This is necessary for casadi to compute sensitivities
143+
self.create_integrator(model, inputs_dict)
144+
solution = self._run_integrator(model, model.y0, inputs_dict, t_eval)
136145
solution.termination = "final time"
137146
return solution
138147
elif self.mode == "fast" or not model.events:
139148
if not model.events:
140149
pybamm.logger.info("No events found, running fast mode")
141150
# Create an integrator with the grid (we just need to do this once)
142-
self.create_integrator(model, inputs, t_eval)
151+
self.create_integrator(model, inputs_dict, t_eval)
143152
solution = self._run_integrator(model, model.y0, inputs_dict, t_eval)
144153
solution.termination = "final time"
145154
return solution
@@ -161,7 +170,7 @@ def _integrate(self, model, t_eval, inputs=None):
161170
# in "safe without grid" mode,
162171
# create integrator once, without grid,
163172
# to avoid having to create several times
164-
self.create_integrator(model, inputs)
173+
self.create_integrator(model, inputs_dict)
165174
# Initialize solution
166175
solution = pybamm.Solution(
167176
np.array([t]), y0[:, np.newaxis], model=model, inputs=inputs_dict
@@ -314,12 +323,15 @@ def event_fun(t):
314323
y0 = solution.y[:, -1]
315324
return solution
316325

317-
def create_integrator(self, model, inputs, t_eval=None):
326+
def create_integrator(self, model, inputs_dict, t_eval=None):
318327
"""
319328
Method to create a casadi integrator object.
320329
If t_eval is provided, the integrator uses t_eval to make the grid.
321330
Otherwise, the integrator has grid [0,1].
322331
"""
332+
# convert inputs to casadi format
333+
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
334+
323335
# Use grid if t_eval is given
324336
use_grid = not (t_eval is None)
325337
# Only set up problem once
@@ -400,6 +412,13 @@ def create_integrator(self, model, inputs, t_eval=None):
400412

401413
def _run_integrator(self, model, y0, inputs_dict, t_eval):
402414
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
415+
symbolic_inputs = casadi.MX.sym("inputs", inputs.shape[0])
416+
# If doing sensitivity with casadi, evaluate with symbolic inputs
417+
# Otherwise, evaluate with actual inputs
418+
if self.sensitivity == "casadi":
419+
inputs_eval = symbolic_inputs
420+
else:
421+
inputs_eval = inputs
403422
integrator, use_grid = self.integrators[model]
404423
# Split up initial conditions into differential and algebraic
405424
# Check y0 to see if it includes sensitivities
@@ -415,10 +434,9 @@ def _run_integrator(self, model, y0, inputs_dict, t_eval):
415434
if use_grid is True:
416435
# Call the integrator once, with the grid
417436
sol = integrator(
418-
x0=y0_diff, z0=y0_alg, p=inputs, **self.extra_options_call
437+
x0=y0_diff, z0=y0_alg, p=inputs_eval, **self.extra_options_call
419438
)
420439
y_sol = np.concatenate([sol["xf"].full(), sol["zf"].full()])
421-
return pybamm.Solution(t_eval, y_sol, model=model, inputs=inputs_dict)
422440
else:
423441
# Repeated calls to the integrator
424442
x = y0_diff
@@ -428,7 +446,7 @@ def _run_integrator(self, model, y0, inputs_dict, t_eval):
428446
for i in range(len(t_eval) - 1):
429447
t_min = t_eval[i]
430448
t_max = t_eval[i + 1]
431-
inputs_with_tlims = casadi.vertcat(inputs, t_min, t_max)
449+
inputs_with_tlims = casadi.vertcat(inputs_eval, t_min, t_max)
432450
sol = integrator(
433451
x0=x, z0=z, p=inputs_with_tlims, **self.extra_options_call
434452
)
@@ -438,14 +456,15 @@ def _run_integrator(self, model, y0, inputs_dict, t_eval):
438456
if not z.is_empty():
439457
y_alg = casadi.horzcat(y_alg, z)
440458
if z.is_empty():
441-
return pybamm.Solution(
442-
t_eval, y_diff, model=model, inputs=inputs_dict
443-
)
459+
y_sol = y_diff
444460
else:
445461
y_sol = casadi.vertcat(y_diff, y_alg)
446-
return pybamm.Solution(
447-
t_eval, y_sol, model=model, inputs=inputs_dict
448-
)
462+
# If doing sensitivity, return the solution as a function of the inputs
463+
if self.sensitivity == "casadi":
464+
y_sol = casadi.Function("y_sol", [symbolic_inputs], [y_sol])
465+
# Save the solution, can just reuse and change the inputs
466+
self.y_sols[model] = y_sol
467+
return pybamm.Solution(t_eval, y_sol, model=model, inputs=inputs_dict)
449468
except RuntimeError as e:
450469
# If it doesn't work raise error
451470
raise pybamm.SolverError(e.args[0])

pybamm/solvers/processed_variable.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -722,11 +722,7 @@ def initialise_1D_symbolic():
722722
+ "implemented)"
723723
)
724724

725-
# Make entries a function and compute jacobian
726-
casadi_entries_fn = casadi.Function(
727-
"variable", [self.symbolic_inputs], [entries_MX]
728-
)
729-
725+
# Compute jacobian
730726
sens_MX = casadi.jacobian(entries_MX, self.symbolic_inputs)
731727
casadi_sens_fn = casadi.Function("variable", [self.symbolic_inputs], [sens_MX])
732728

0 commit comments

Comments
 (0)