Skip to content

Commit 54d45f9

Browse files
#1100 get SODEs working for casadi solver
1 parent f97c952 commit 54d45f9

File tree

6 files changed

+336
-29
lines changed

6 files changed

+336
-29
lines changed

pybamm/solvers/base_solver.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import sys
1010
import itertools
11+
from scipy.linalg import block_diag
1112

1213

1314
class BaseSolver(object):
@@ -426,12 +427,21 @@ def report(string):
426427
):
427428
# can use DAE solver to solve model with algebraic equations only
428429
if len(model.rhs) > 0:
429-
mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries)
430+
if self.solve_sensitivity_equations is True:
431+
# Copy mass matrix blocks diagonally
432+
single_mass_matrix_inv = model.mass_matrix_inv.entries.toarray()
433+
n_inputs = p_casadi_stacked.shape[0]
434+
block_mass_matrix = block_diag(
435+
*[single_mass_matrix_inv] * (n_inputs + 1)
436+
)
437+
mass_matrix_inv = casadi.MX(block_mass_matrix)
438+
else:
439+
mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries)
430440
explicit_rhs = mass_matrix_inv @ rhs(
431-
t_casadi, y_casadi, p_casadi_stacked
441+
t_casadi, y_and_S, p_casadi_stacked
432442
)
433443
model.casadi_rhs = casadi.Function(
434-
"rhs", [t_casadi, y_casadi, p_casadi_stacked], [explicit_rhs]
444+
"rhs", [t_casadi, y_and_S, p_casadi_stacked], [explicit_rhs]
435445
)
436446
model.casadi_algebraic = algebraic
437447
if len(model.rhs) == 0:
@@ -703,10 +713,6 @@ def solve(self, model, t_eval=None, external_variables=None, inputs=None):
703713
solution.set_up_time = set_up_time
704714
solution.solve_time = timer.time()
705715

706-
# Add model and inputs to solution
707-
solution.model = model
708-
solution.inputs = ext_and_inputs
709-
710716
# Identify the event that caused termination
711717
termination = self.get_termination_reason(solution, model.events)
712718

pybamm/solvers/casadi_solver.py

+44-15
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ class CasadiSolver(pybamm.BaseSolver):
5555
Any options to pass to the CasADi integrator when calling the integrator.
5656
Please consult `CasADi documentation <https://tinyurl.com/y5rk76os>`_ for
5757
details.
58+
solve_sensitivity_equations : bool, optional
59+
Whether to explicitly formulate and solve the forward sensitivity equations.
60+
See :class:`pybamm.BaseSolver`
5861
5962
"""
6063

@@ -69,8 +72,16 @@ def __init__(
6972
dt_max=None,
7073
extra_options_setup=None,
7174
extra_options_call=None,
75+
solve_sensitivity_equations=False,
7276
):
73-
super().__init__("problem dependent", rtol, atol, root_method, root_tol)
77+
super().__init__(
78+
"problem dependent",
79+
rtol,
80+
atol,
81+
root_method,
82+
root_tol,
83+
solve_sensitivity_equations=solve_sensitivity_equations,
84+
)
7485
if mode in ["safe", "fast"]:
7586
self.mode = mode
7687
else:
@@ -106,24 +117,26 @@ def _integrate(self, model, t_eval, inputs=None):
106117
Any external variables or input parameters to pass to the model when solving
107118
"""
108119
# Record whether there are any symbolic inputs
109-
inputs = inputs or {}
110-
has_symbolic_inputs = any(isinstance(v, casadi.MX) for v in inputs.values())
120+
inputs_dict = inputs or {}
121+
has_symbolic_inputs = any(
122+
isinstance(v, casadi.MX) for v in inputs_dict.values()
123+
)
111124

112125
# convert inputs to casadi format
113-
inputs = casadi.vertcat(*[x for x in inputs.values()])
126+
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
114127

115128
if has_symbolic_inputs:
116129
# Create integrax`tor without grid to avoid having to create several times
117130
self.get_integrator(model, inputs)
118-
solution = self._run_integrator(model, model.y0, inputs, t_eval)
131+
solution = self._run_integrator(model, model.y0, inputs_dict, t_eval)
119132
solution.termination = "final time"
120133
return solution
121134
elif self.mode == "fast" or not model.events:
122135
if not model.events:
123136
pybamm.logger.info("No events found, running fast mode")
124137
# Create an integrator with the grid (we just need to do this once)
125138
self.get_integrator(model, inputs, t_eval)
126-
solution = self._run_integrator(model, model.y0, inputs, t_eval)
139+
solution = self._run_integrator(model, model.y0, inputs_dict, t_eval)
127140
solution.termination = "final time"
128141
return solution
129142
elif self.mode == "safe":
@@ -143,7 +156,9 @@ def _integrate(self, model, t_eval, inputs=None):
143156
pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))
144157

145158
# Initialize solution
146-
solution = pybamm.Solution(np.array([t]), y0[:, np.newaxis])
159+
solution = pybamm.Solution(
160+
np.array([t]), y0[:, np.newaxis], model=model, inputs=inputs_dict
161+
)
147162
solution.solve_time = 0
148163

149164
# Try to integrate in global steps of size dt_max. Note: dt_max must
@@ -178,7 +193,7 @@ def _integrate(self, model, t_eval, inputs=None):
178193
# halve the step size and try again.
179194
try:
180195
current_step_sol = self._run_integrator(
181-
model, y0, inputs, t_window
196+
model, y0, inputs_dict, t_window
182197
)
183198
solved = True
184199
except pybamm.SolverError:
@@ -257,7 +272,9 @@ def event_fun(t):
257272
t_window = np.array([t, t_event])
258273

259274
# integrator = self.get_integrator(model, t_window, inputs)
260-
current_step_sol = self._run_integrator(model, y0, inputs, t_window)
275+
current_step_sol = self._run_integrator(
276+
model, y0, inputs_dict, t_window
277+
)
261278

262279
# assign temporary solve time
263280
current_step_sol.solve_time = np.nan
@@ -361,10 +378,18 @@ def get_integrator(self, model, inputs, t_eval=None):
361378
self.integrators[model] = (integrator, use_grid)
362379
return integrator
363380

364-
def _run_integrator(self, model, y0, inputs, t_eval):
381+
def _run_integrator(self, model, y0, inputs_dict, t_eval):
382+
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
365383
integrator, use_grid = self.integrators[model]
366-
y0_diff = y0[: model.len_rhs]
367-
y0_alg = y0[model.len_rhs :]
384+
# Split up initial conditions into differential and algebraic
385+
# Check y0 to see if it includes sensitivities
386+
if model.len_rhs_and_alg == y0.shape[0]:
387+
len_rhs = model.len_rhs
388+
else:
389+
len_rhs = model.len_rhs * (inputs.shape[0] + 1)
390+
y0_diff = y0[:len_rhs]
391+
y0_alg = y0[len_rhs:]
392+
# Solve
368393
try:
369394
# Try solving
370395
if use_grid is True:
@@ -379,7 +404,7 @@ def _run_integrator(self, model, y0, inputs, t_eval):
379404
**self.extra_options_call
380405
)
381406
y_sol = np.concatenate([sol["xf"].full(), sol["zf"].full()])
382-
return pybamm.Solution(t_eval, y_sol)
407+
return pybamm.Solution(t_eval, y_sol, model=model, inputs=inputs_dict)
383408
else:
384409
# Repeated calls to the integrator
385410
x = y0_diff
@@ -399,10 +424,14 @@ def _run_integrator(self, model, y0, inputs, t_eval):
399424
if not z.is_empty():
400425
y_alg = casadi.horzcat(y_alg, z)
401426
if z.is_empty():
402-
return pybamm.Solution(t_eval, y_diff)
427+
return pybamm.Solution(
428+
t_eval, y_diff, model=model, inputs=inputs_dict
429+
)
403430
else:
404431
y_sol = casadi.vertcat(y_diff, y_alg)
405-
return pybamm.Solution(t_eval, y_sol)
432+
return pybamm.Solution(
433+
t_eval, y_sol, model=model, inputs=inputs_dict
434+
)
406435
except RuntimeError as e:
407436
# If it doesn't work raise error
408437
raise pybamm.SolverError(e.args[0])

pybamm/solvers/scipy_solver.py

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class ScipySolver(pybamm.BaseSolver):
2424
Please consult `SciPy documentation <https://tinyurl.com/yafgqg9y>`_ for
2525
details.
2626
solve_sensitivity_equations : bool, optional
27+
Whether to explicitly formulate and solve the forward sensitivity equations.
2728
See :class:`pybamm.BaseSolver`
2829
"""
2930

pybamm/solvers/solution.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
# y only has the shape or the rhs and alg solution)
6767
if model is None or model.len_rhs_and_alg == y.shape[0]:
6868
self._y = y
69+
self.sensitivity = {}
6970
else:
7071
n_states = model.len_rhs_and_alg
7172
n_t = len(t)
@@ -133,11 +134,9 @@ def __init__(
133134
if copy_this is None:
134135
self.set_up_time = None
135136
self.solve_time = None
136-
self.has_symbolic_inputs = False
137137
else:
138138
self.set_up_time = copy_this.set_up_time
139139
self.solve_time = copy_this.solve_time
140-
self.has_symbolic_inputs = copy_this.has_symbolic_inputs
141140

142141
# initiaize empty variables and data
143142
self._variables = pybamm.FuzzyDict()

0 commit comments

Comments
 (0)