Skip to content

Commit 5214994

Browse files
committed
#1477 generalising 'explicit forward' option so any solver can use it
1 parent 0cbe0a5 commit 5214994

File tree

3 files changed

+57
-27
lines changed

3 files changed

+57
-27
lines changed

pybamm/solvers/base_solver.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
import sys
1010
import itertools
11-
from scipy.linalg import block_diag
11+
from scipy.sparse import block_diag
1212
import multiprocessing as mp
1313
import warnings
1414

@@ -241,6 +241,8 @@ def set_up(self, model, inputs=None, t_eval=None,
241241
# save sensitivity parameters so we can identify them later on
242242
# (FYI: this is used in the Solution class)
243243
model.calculate_sensitivities = calculate_sensitivites
244+
model.len_rhs_sens = model.len_rhs * len(calculate_sensitivites)
245+
model.len_alg_sens = model.len_alg * len(calculate_sensitivites)
244246

245247
# Only allow solving explicit sensitivity equations with the casadi format for now
246248
if (
@@ -277,8 +279,6 @@ def set_up(self, model, inputs=None, t_eval=None,
277279
pS_casadi_stacked = casadi.vertcat(
278280
*[p_casadi[name] for name in calculate_sensitivites]
279281
)
280-
model.len_rhs_sens = model.len_rhs * pS_casadi_stacked.shape[0]
281-
model.len_alg_sens = model.len_alg * pS_casadi_stacked.shape[0]
282282
S_x = casadi.MX.sym("S_x", model.len_rhs_sens)
283283
S_z = casadi.MX.sym("S_z", model.len_alg_sens)
284284
y_and_S = casadi.vertcat(y_diff, S_x, y_alg, S_z)
@@ -615,6 +615,21 @@ def jacp(*args, **kwargs):
615615
interpolant_extrapolation_events_eval
616616
)
617617

618+
# if we have changed the equations to include the explicit sensitivity
619+
# equations, then we also need to update the mass matrix
620+
if self.sensitivity == "explicit forward":
621+
n_inputs = len(calculate_sensitivites)
622+
model.mass_matrix_inv = pybamm.Matrix(
623+
block_diag(
624+
[model.mass_matrix_inv.entries] * (n_inputs + 1), format="csr"
625+
)
626+
)
627+
model.mass_matrix = pybamm.Matrix(
628+
block_diag(
629+
[model.mass_matrix.entries] * (n_inputs + 1), format="csr"
630+
)
631+
)
632+
618633
# Save CasADi functions for the CasADi solver
619634
# Note: when we pass to casadi the ode part of the problem must be in explicit
620635
# form so we pre-multiply by the inverse of the mass matrix
@@ -623,16 +638,7 @@ def jacp(*args, **kwargs):
623638
):
624639
# can use DAE solver to solve model with algebraic equations only
625640
if len(model.rhs) > 0:
626-
if self.sensitivity == "explicit forward":
627-
# Copy mass matrix blocks diagonally
628-
single_mass_matrix_inv = model.mass_matrix_inv.entries.toarray()
629-
n_inputs = p_casadi_stacked.shape[0]
630-
block_mass_matrix = block_diag(
631-
*[single_mass_matrix_inv] * (n_inputs + 1)
632-
)
633-
mass_matrix_inv = casadi.MX(block_mass_matrix)
634-
else:
635-
mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries)
641+
mass_matrix_inv = casadi.MX(model.mass_matrix_inv.entries)
636642
explicit_rhs = mass_matrix_inv @ rhs(
637643
t_casadi, y_and_S, p_casadi_stacked
638644
)
@@ -754,8 +760,7 @@ def calculate_consistent_state(self, model, time=0, inputs=None):
754760
)
755761
pybamm.logger.debug("Found consistent states")
756762

757-
# use all_ys_and_sens in case we are solving the full sensitivity equations
758-
y0 = root_sol.all_ys_and_sens[0]
763+
y0 = root_sol.all_ys[0]
759764
if isinstance(y0, np.ndarray):
760765
y0 = y0.flatten()
761766
return y0

pybamm/solvers/casadi_solver.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,11 @@ def _integrate(self, model, t_eval, inputs_dict=None):
131131
inputs_dict : dict, optional
132132
Any external variables or input parameters to pass to the model when solving
133133
"""
134+
135+
136+
# are we solving explicit forward equations?
137+
explicit_sensitivities = self.sensitivity == 'explicit forward'
138+
134139
# Record whether there are any symbolic inputs
135140
inputs_dict = inputs_dict or {}
136141

@@ -158,14 +163,15 @@ def _integrate(self, model, t_eval, inputs_dict=None):
158163
# Create integrator without grid to avoid having to create several times
159164
self.create_integrator(model, inputs)
160165
solution = self._run_integrator(
161-
model, model.y0, inputs_dict, inputs, t_eval, use_grid=False
166+
model, model.y0, inputs_dict, inputs, t_eval, use_grid=False,
162167
)
163168
if self.sensitivity == "casadi" and inputs_dict != {}:
164169
# If the solution has already been created, we can reuse it
165170
if model in self.y_sols:
166171
y_sol = self.y_sols[model]
167172
solution = pybamm.Solution(
168-
t_eval, y_sol, model=model, inputs=inputs_dict
173+
t_eval, y_sol, model=model, inputs=inputs_dict,
174+
sensitivities=explicit_sensitivities
169175
)
170176
else:
171177
# Create integrator without grid, which will be called repeatedly
@@ -212,7 +218,10 @@ def _integrate(self, model, t_eval, inputs_dict=None):
212218
# to avoid having to create several times
213219
self.create_integrator(model, inputs_dict)
214220
# Initialize solution
215-
solution = pybamm.Solution(np.array([t]), y0, model, inputs_dict)
221+
solution = pybamm.Solution(
222+
np.array([t]), y0, model, inputs_dict,
223+
sensitivities=explicit_sensitivities
224+
)
216225
solution.solve_time = 0
217226
solution.integration_time = 0
218227
use_grid = False
@@ -455,6 +464,7 @@ def integer_bisect():
455464
np.array([t_event]),
456465
y_event[:, np.newaxis],
457466
"event",
467+
sensitivities=explicit_sensitivities
458468
)
459469
solution.integration_time = (
460470
coarse_solution.integration_time + dense_step_sol.integration_time
@@ -613,6 +623,10 @@ def create_integrator(self, model, inputs, t_eval=None, use_event_switch=False):
613623

614624
def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True):
615625
pybamm.logger.debug("Running CasADi integrator")
626+
627+
# are we solving explicit forward equations?
628+
explicit_sensitivities = self.sensitivity == 'explicit forward'
629+
616630
if use_grid is True:
617631
t_eval_shifted = t_eval - t_eval[0]
618632
t_eval_shifted_rounded = np.round(t_eval_shifted, decimals=12).tobytes()
@@ -649,7 +663,10 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
649663
)
650664
integration_time = timer.time()
651665
y_sol = casadi.vertcat(casadi_sol["xf"], casadi_sol["zf"])
652-
sol = pybamm.Solution(t_eval, y_sol, model, inputs_dict)
666+
sol = pybamm.Solution(
667+
t_eval, y_sol, model, inputs_dict,
668+
sensitivities=explicit_sensitivities
669+
)
653670
sol.integration_time = integration_time
654671
return sol
655672
else:
@@ -682,7 +699,10 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
682699
# Save the solution, can just reuse and change the inputs
683700
self.y_sols[model] = y_sol
684701

685-
sol = pybamm.Solution(t_eval, y_sol, model, inputs_dict)
702+
sol = pybamm.Solution(
703+
t_eval, y_sol, model, inputs_dict,
704+
sensitivities=explicit_sensitivities
705+
)
686706
sol.integration_time = integration_time
687707
return sol
688708
except RuntimeError as e:

pybamm/solvers/solution.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ class Solution(object):
4141
the event happens.
4242
termination : str
4343
String to indicate why the solution terminated
44-
sensitivities: None or dict
45-
Will be None if there are no sensitivities in this soluion. Otherwise, this is a
46-
dict of parameter names to their calcululated sensitivities
44+
45+
sensitivities: bool or dict
46+
True if sensitivities included as the solution of the explicit forwards
47+
equations. False if no sensitivities included/wanted. Dict if sensitivities are
48+
provided as a dict of {parameter: sensitivities} pairs.
4749
4850
"""
4951

@@ -56,7 +58,7 @@ def __init__(
5658
t_event=None,
5759
y_event=None,
5860
termination="final time",
59-
sensitivities=None
61+
sensitivities=False
6062
):
6163
if not isinstance(all_ts, list):
6264
all_ts = [all_ts]
@@ -80,11 +82,12 @@ def __init__(
8082
self.all_inputs = all_inputs
8183

8284
# sensitivities
83-
if sensitivities is None:
85+
if isinstance(sensitivities, bool):
8486
self._sensitivities = {}
8587
# if solution consists of explicit sensitivity equations, extract them
8688
if (
87-
all_models[0] is not None
89+
sensitivities == True
90+
and all_models[0] is not None
8891
and not isinstance(all_ys[0], casadi.Function)
8992
and all_models[0].len_rhs_and_alg != all_ys[0].shape[0]
9093
and all_models[0].len_rhs_and_alg != 0 # for the dummy solver
@@ -95,8 +98,10 @@ def __init__(
9598
self._extract_explicit_sensitivities(
9699
all_models[0], all_ys[0], all_ts[0], self.all_inputs[0]
97100
)
98-
else:
101+
elif isinstance(sensitivities, dict):
99102
self._sensitivities = sensitivities
103+
else:
104+
raise RuntimeError('sensitivities arg needs to be a bool or dict')
100105

101106
self._t_event = t_event
102107
self._y_event = y_event

0 commit comments

Comments
 (0)