Skip to content

Commit 0cfdf6f

Browse files
#1100 reformatted sensitivity API
1 parent 27475b8 commit 0cfdf6f

8 files changed

+296
-314
lines changed

pybamm/solvers/base_solver.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,15 @@ class BaseSolver(object):
3131
specified by 'root_method' (e.g. "lm", "hybr", ...)
3232
root_tol : float, optional
3333
The tolerance for the initial-condition solver (default is 1e-6).
34-
sensitivity : bool, optional
35-
Whether to explicitly formulate the sensitivity equations for sensitivity
36-
to input parameters. The formulation is as per "Park, S., Kato, D., Gima, Z.,
34+
sensitivity : str, optional
35+
Whether (and how) to calculate sensitivities when solving. Options are:
36+
37+
- "explicit forward": explicitly formulate the sensitivity equations.
38+
The formulation is as per "Park, S., Kato, D., Gima, Z.,
3739
Klein, R., & Moura, S. (2018). Optimal experimental design for parameterization
3840
of an electrochemical lithium-ion battery model. Journal of The Electrochemical
3941
Society, 165(7), A1309.". See #1100 for details
42+
- see specific solvers for other options
4043
"""
4144

4245
def __init__(
@@ -47,7 +50,7 @@ def __init__(
4750
root_method=None,
4851
root_tol=1e-6,
4952
max_steps="deprecated",
50-
sensitivity=False,
53+
sensitivity=None,
5154
):
5255
self._method = method
5356
self._rtol = rtol

pybamm/solvers/casadi_algebraic_solver.py

+38-13
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
import casadi
55
import pybamm
6+
import numbers
67
import numpy as np
78

89

@@ -21,10 +22,18 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver):
2122
Any options to pass to the CasADi rootfinder.
2223
Please consult `CasADi documentation <https://tinyurl.com/y7hrxm7d>`_ for
2324
details.
25+
sensitivity : str, optional
26+
Whether (and how) to calculate sensitivities when solving. Options are:
27+
28+
- None: no sensitivities
29+
- "explicit forward": explicitly formulate the sensitivity equations.
30+
See :class:`pybamm.BaseSolver`
31+
- "casadi": use casadi to differentiate through the rootfinding operator
32+
2433
"""
2534

26-
def __init__(self, tol=1e-6, extra_options=None):
27-
super().__init__()
35+
def __init__(self, tol=1e-6, extra_options=None, sensitivity=None):
36+
super().__init__(sensitivity=sensitivity)
2837
self.tol = tol
2938
self.name = "CasADi algebraic solver"
3039
self.algebraic_solver = True
@@ -57,14 +66,24 @@ def _integrate(self, model, t_eval, inputs=None):
5766
"""
5867
# Record whether there are any symbolic inputs
5968
inputs = inputs or {}
60-
has_symbolic_inputs = any(isinstance(v, casadi.MX) for v in inputs.values())
61-
symbolic_inputs = casadi.vertcat(
62-
*[v for v in inputs.values() if isinstance(v, casadi.MX)]
63-
)
64-
6569
# Create casadi objects for the root-finder
70+
inputs_dict = inputs
6671
inputs = casadi.vertcat(*[v for v in inputs.values()])
6772

73+
if self.sensitivity == "casadi" and inputs_dict != {}:
74+
# Create symbolic inputs for sensitivity analysis
75+
symbolic_inputs_list = []
76+
for name, value in inputs_dict.items():
77+
if isinstance(value, numbers.Number):
78+
symbolic_inputs_list.append(casadi.MX.sym(name))
79+
else:
80+
symbolic_inputs_list.append(casadi.MX.sym(name, value.shape[0]))
81+
symbolic_inputs = casadi.vertcat(*[p for p in symbolic_inputs_list])
82+
inputs_for_alg = symbolic_inputs
83+
else:
84+
symbolic_inputs = casadi.DM()
85+
inputs_for_alg = inputs
86+
6887
y0 = model.y0
6988
# The casadi algebraic solver can read rhs equations, but leaves them unchanged
7089
# i.e. the part of the solution vector that corresponds to the differential
@@ -91,7 +110,7 @@ def _integrate(self, model, t_eval, inputs=None):
91110
y_sym = casadi.vertcat(y0_diff, y_alg_sym)
92111

93112
t_and_inputs_sym = casadi.vertcat(t_sym, symbolic_inputs)
94-
alg = model.casadi_algebraic(t_sym, y_sym, inputs)
113+
alg = model.casadi_algebraic(t_sym, y_sym, inputs_for_alg)
95114

96115
# Set constraints vector in the casadi format
97116
# Constrain the unknowns. 0 (default): no constraint on ui, 1: ui >= 0.0,
@@ -116,8 +135,8 @@ def _integrate(self, model, t_eval, inputs=None):
116135
for idx, t in enumerate(t_eval):
117136
# Evaluate algebraic with new t and previous y0, if it's already close
118137
# enough then keep it
119-
# We can't do this if there are symbolic inputs
120-
if has_symbolic_inputs is False and np.all(
138+
# We can't do this if also doing sensitivity
139+
if self.sensitivity != "casadi" and np.all(
121140
abs(model.casadi_algebraic(t, y0, inputs).full()) < self.tol
122141
):
123142
pybamm.logger.debug(
@@ -144,9 +163,9 @@ def _integrate(self, model, t_eval, inputs=None):
144163
fun = None
145164

146165
# If there are no symbolic inputs, check the function is below the tol
147-
# Skip this check if there are symbolic inputs
166+
# Skip this check if also doing sensitivity
148167
if success and (
149-
has_symbolic_inputs is True or np.all(casadi.fabs(fun) < self.tol)
168+
self.sensitivity == "casadi" or np.all(casadi.fabs(fun) < self.tol)
150169
):
151170
# update initial guess for the next iteration
152171
y0_alg = y_alg_sol
@@ -173,5 +192,11 @@ def _integrate(self, model, t_eval, inputs=None):
173192
# Concatenate differential part
174193
y_diff = casadi.horzcat(*[y0_diff] * len(t_eval))
175194
y_sol = casadi.vertcat(y_diff, y_alg)
195+
196+
# If doing sensitivity, return the solution as a function of the inputs
197+
if self.sensitivity == "casadi":
198+
y_sol = casadi.Function("y_sol", [symbolic_inputs], [y_sol])
176199
# Return solution object (no events, so pass None to t_event, y_event)
177-
return pybamm.Solution(t_eval, y_sol, termination="success")
200+
return pybamm.Solution(
201+
t_eval, y_sol, termination="success", model=model, inputs=inputs_dict
202+
)

pybamm/solvers/casadi_solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(
7575
dt_max=None,
7676
extra_options_setup=None,
7777
extra_options_call=None,
78-
sensitivity=False,
78+
sensitivity=None,
7979
):
8080
super().__init__(
8181
"problem dependent",

pybamm/solvers/processed_variable.py

+121-4
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,40 @@ def __init__(self, base_variable, solution, known_evals=None, warn=True):
6464
self._sensitivity = None
6565
self.solution_sensitivity = solution.sensitivity
6666

67+
# Special case: symbolic solution, with casadi
68+
if isinstance(solution.y, casadi.Function):
69+
# Evaluate solution at specific inputs value
70+
inputs_stacked = casadi.vertcat(*solution.inputs.values())
71+
self.u_sol = solution.y(inputs_stacked).full()
72+
# Convert variable to casadi
73+
t_MX = casadi.MX.sym("t")
74+
y_MX = casadi.MX.sym("y", self.u_sol.shape[0])
75+
# Make all inputs symbolic first for converting to casadi
76+
symbolic_inputs_dict = {
77+
name: casadi.MX.sym(name, value.shape[0])
78+
for name, value in solution.inputs.items()
79+
}
80+
81+
# The symbolic_inputs will be used for sensitivity
82+
symbolic_inputs = casadi.vertcat(*symbolic_inputs_dict.values())
83+
try:
84+
var_casadi = base_variable.to_casadi(
85+
t_MX, y_MX, inputs=symbolic_inputs_dict
86+
)
87+
except:
88+
n = 1
89+
self.base_variable_sym = casadi.Function(
90+
"variable", [t_MX, y_MX, symbolic_inputs], [var_casadi]
91+
)
92+
# Store symbolic inputs for sensitivity
93+
self.symbolic_inputs = symbolic_inputs
94+
self.y_sym = solution.y(symbolic_inputs)
95+
else:
96+
self.u_sol = solution.y
97+
self.base_variable_sym = None
98+
self.symbolic_inputs = None
99+
self.y_sym = None
100+
67101
# Set timescale
68102
self.timescale = solution.model.timescale.evaluate()
69103
self.t_pts = self.t_sol * self.timescale
@@ -78,8 +112,8 @@ def __init__(self, base_variable, solution, known_evals=None, warn=True):
78112
# Evaluate base variable at initial time
79113
if self.known_evals:
80114
self.base_eval, self.known_evals[solution.t[0]] = base_variable.evaluate(
81-
solution.t[0],
82-
solution.y[:, 0],
115+
self.t_sol[0],
116+
self.u_sol[:, 0],
83117
inputs={name: inp[:, 0] for name, inp in solution.inputs.items()},
84118
known_evals=self.known_evals[solution.t[0]],
85119
)
@@ -571,10 +605,20 @@ def sensitivity(self):
571605
return {}
572606
# Otherwise initialise and return sensitivity
573607
if self._sensitivity is None:
574-
self.initialise_sensitivity()
608+
# Check that we can compute sensitivities
609+
if self.base_variable_sym is None and self.solution_sensitivity == {}:
610+
raise ValueError(
611+
"Cannot compute sensitivities. The 'sensitivity' argument of the "
612+
"solver should be changed from 'None' to allow sensitivity "
613+
"calculations. Check solver documentation for details."
614+
)
615+
if self.base_variable_sym is None:
616+
self.initialise_sensitivity_explicit_forward()
617+
else:
618+
self.initialise_sensitivity_casadi()
575619
return self._sensitivity
576620

577-
def initialise_sensitivity(self):
621+
def initialise_sensitivity_explicit_forward(self):
578622
"Set up the sensitivity dictionary"
579623
inputs_stacked = casadi.vertcat(*[p for p in self.inputs.values()])
580624

@@ -628,6 +672,79 @@ def initialise_sensitivity(self):
628672
# Save attribute
629673
self._sensitivity = sensitivity
630674

675+
def initialise_sensitivity_casadi(self):
676+
def initialise_0D_symbolic():
677+
"Create a 0D symbolic variable"
678+
# Evaluate the base_variable index-by-index
679+
for idx in range(len(self.t_sol)):
680+
t = self.t_sol[idx]
681+
u = self.y_sym[:, idx]
682+
next_entries = self.base_variable_sym(t, u, self.symbolic_inputs)
683+
if idx == 0:
684+
entries = next_entries
685+
else:
686+
entries = casadi.horzcat(entries, next_entries)
687+
688+
return entries
689+
690+
def initialise_1D_symbolic():
691+
"Create a 1D symbolic variable"
692+
# Evaluate the base_variable index-by-index
693+
for idx in range(len(self.t_sol)):
694+
t = self.t_sol[idx]
695+
u = self.y_sym[:, idx]
696+
next_entries = self.base_variable_sym(t, u, self.symbolic_inputs)
697+
if idx == 0:
698+
entries = next_entries
699+
else:
700+
entries = casadi.vertcat(entries, next_entries)
701+
702+
return entries
703+
704+
inputs_stacked = casadi.vertcat(*self.inputs.values())
705+
self.base_eval = self.base_variable_sym(
706+
self.t_sol[0], self.u_sol[:, 0], inputs_stacked
707+
)
708+
if (
709+
isinstance(self.base_eval, numbers.Number)
710+
or len(self.base_eval.shape) == 0
711+
or self.base_eval.shape[0] == 1
712+
):
713+
entries_MX = initialise_0D_symbolic()
714+
else:
715+
n = self.mesh.npts
716+
base_shape = self.base_eval.shape[0]
717+
# Try shape that could make the variable a 1D variable
718+
if base_shape == n:
719+
entries_MX = initialise_1D_symbolic()
720+
else:
721+
# Raise error for 2D variable
722+
raise NotImplementedError(
723+
"Shape not recognized for {} ".format(self.base_variable)
724+
+ "(note processing of 2D and 3D variables is not yet "
725+
+ "implemented)"
726+
)
727+
728+
# Make entries a function and compute jacobian
729+
casadi_entries_fn = casadi.Function(
730+
"variable", [self.symbolic_inputs], [entries_MX]
731+
)
732+
733+
sens_MX = casadi.jacobian(entries_MX, self.symbolic_inputs)
734+
casadi_sens_fn = casadi.Function("variable", [self.symbolic_inputs], [sens_MX])
735+
736+
sens_eval = casadi_sens_fn(inputs_stacked)
737+
sensitivity = {"all": sens_eval}
738+
739+
# Add the individual sensitivity
740+
start = 0
741+
for name, inp in self.inputs.items():
742+
end = start + inp.shape[0]
743+
sensitivity[name] = sens_eval[:, start:end]
744+
start = end
745+
746+
self._sensitivity = sensitivity
747+
631748

632749
def eval_dimension_name(name, x, r, y, z):
633750
if name == "x":

pybamm/solvers/scipy_solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class ScipySolver(pybamm.BaseSolver):
2929
"""
3030

3131
def __init__(
32-
self, method="BDF", rtol=1e-6, atol=1e-6, extra_options=None, sensitivity=False,
32+
self, method="BDF", rtol=1e-6, atol=1e-6, extra_options=None, sensitivity=None,
3333
):
3434
super().__init__(
3535
method=method, rtol=rtol, atol=atol, sensitivity=sensitivity,

pybamm/solvers/solution.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,13 @@ def __init__(
6464
# If the model has been provided, split up y into solution and sensitivity
6565
# Don't do this if the sensitivity equations have not been computed (i.e. if
6666
# y only has the shape or the rhs and alg solution)
67-
if model is None or model.len_rhs_and_alg == y.shape[0]:
67+
# Don't do this if y is symbolic (sensitivities will be calculated a different
68+
# way)
69+
if (
70+
model is None
71+
or isinstance(y, casadi.Function)
72+
or model.len_rhs_and_alg == y.shape[0]
73+
):
6874
self._y = y
6975
self.sensitivity = {}
7076
else:
@@ -243,7 +249,6 @@ def update(self, variables):
243249
var = pybamm.ProcessedVariable(
244250
self.model.variables[key], self, self._known_evals
245251
)
246-
247252
# Update known_evals in order to process any other variables faster
248253
for t in var.known_evals:
249254
self._known_evals[t].update(var.known_evals[t])

0 commit comments

Comments
 (0)