Skip to content

Commit a59abb5

Browse files
#1100 starting to add SDAEs (odes only for now)
1 parent 29d617a commit a59abb5

8 files changed

+366
-31
lines changed

pybamm/discretisations/discretisation.py

+5
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@ def process_model(self, model, inplace=True, check_model=True):
205205
model_disc.rhs, model_disc.concatenated_rhs = rhs, concat_rhs
206206
model_disc.algebraic, model_disc.concatenated_algebraic = alg, concat_alg
207207

208+
# Save length of rhs and algebraic
209+
model_disc.len_rhs = model_disc.concatenated_rhs.size
210+
model_disc.len_alg = model_disc.concatenated_algebraic.size
211+
model_disc.len_rhs_and_alg = model_disc.len_rhs + model_disc.len_alg
212+
208213
# Process events
209214
processed_events = []
210215
pybamm.logger.info("Discretise events for {}".format(model.name))

pybamm/solvers/base_solver.py

+71-11
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ class BaseSolver(object):
3030
specified by 'root_method' (e.g. "lm", "hybr", ...)
3131
root_tol : float, optional
3232
The tolerance for the initial-condition solver (default is 1e-6).
33+
solve_sensitivity_equations : bool, optional
34+
Whether to explicitly formulate the sensitivity equations for sensitivity
35+
to input parameters. The formulation is as per "Park, S., Kato, D., Gima, Z.,
36+
Klein, R., & Moura, S. (2018). Optimal experimental design for parameterization
37+
of an electrochemical lithium-ion battery model. Journal of The Electrochemical
38+
Society, 165(7), A1309.". See #1100 for details
3339
"""
3440

3541
def __init__(
@@ -40,6 +46,7 @@ def __init__(
4046
root_method=None,
4147
root_tol=1e-6,
4248
max_steps="deprecated",
49+
solve_sensitivity_equations=False,
4350
):
4451
self._method = method
4552
self._rtol = rtol
@@ -57,6 +64,7 @@ def __init__(
5764
self.name = "Base solver"
5865
self.ode_solver = False
5966
self.algebraic_solver = False
67+
self.solve_sensitivity_equations = solve_sensitivity_equations
6068

6169
@property
6270
def method(self):
@@ -191,17 +199,28 @@ def set_up(self, model, inputs=None):
191199
)
192200
model.convert_to_format = "casadi"
193201

202+
# Only allow solving sensitivity equations with the casadi format for now
203+
if (
204+
self.solve_sensitivity_equations is True
205+
and model.convert_to_format != "casadi"
206+
):
207+
raise NotImplementedError(
208+
"model should be converted to casadi format in order to solve "
209+
"sensitivity equations"
210+
)
211+
194212
if model.convert_to_format != "casadi":
195213
simp = pybamm.Simplification()
196214
# Create Jacobian from concatenated rhs and algebraic
197-
y = pybamm.StateVector(slice(0, model.concatenated_initial_conditions.size))
215+
y = pybamm.StateVector(slice(0, model.len_rhs_and_alg))
198216
# set up Jacobian object, for re-use of dict
199217
jacobian = pybamm.Jacobian()
200218
else:
201219
# Convert model attributes to casadi
202220
t_casadi = casadi.MX.sym("t")
203-
y_diff = casadi.MX.sym("y_diff", model.concatenated_rhs.size)
204-
y_alg = casadi.MX.sym("y_alg", model.concatenated_algebraic.size)
221+
# Create the symbolic state vectors
222+
y_diff = casadi.MX.sym("y_diff", model.len_rhs)
223+
y_alg = casadi.MX.sym("y_alg", model.len_alg)
205224
y_casadi = casadi.vertcat(y_diff, y_alg)
206225
p_casadi = {}
207226
for name, value in inputs.items():
@@ -210,6 +229,13 @@ def set_up(self, model, inputs=None):
210229
else:
211230
p_casadi[name] = casadi.MX.sym(name, value.shape[0])
212231
p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()])
232+
# sensitivity vectors
233+
if self.solve_sensitivity_equations is True:
234+
S_x = casadi.MX.sym("S_x", model.len_rhs * p_casadi_stacked.shape[0])
235+
S_z = casadi.MX.sym("S_z", model.len_alg * p_casadi_stacked.shape[0])
236+
y_and_S = casadi.vertcat(y_diff, S_x, y_alg, S_z)
237+
else:
238+
y_and_S = y_casadi
213239

214240
def process(func, name, use_jacobian=None):
215241
def report(string):
@@ -258,16 +284,40 @@ def report(string):
258284
# Process with CasADi
259285
report(f"Converting {name} to CasADi")
260286
func = func.to_casadi(t_casadi, y_casadi, inputs=p_casadi)
287+
# Add sensitivity vectors to the rhs and algebraic equations
288+
if self.solve_sensitivity_equations is True:
289+
if name == "rhs":
290+
report(f"Creating sensitivity equations for rhs using CasADi")
291+
df_dx = casadi.jacobian(func, y_diff)
292+
df_dp = casadi.jacobian(func, p_casadi_stacked)
293+
if model.len_alg == 0:
294+
S_rhs = df_dx @ S_x + df_dp
295+
else:
296+
df_dz = casadi.jacobian(func, y_alg)
297+
S_rhs = df_dx @ S_x + df_dz @ S_z + df_dp
298+
func = casadi.vertcat(func, S_rhs)
299+
elif name == "initial_conditions":
300+
if model.len_rhs == 0 or model.len_alg == 0:
301+
S_0 = casadi.jacobian(func, p_casadi_stacked).reshape(
302+
(-1, 1)
303+
)
304+
func = casadi.vertcat(func, S_0)
305+
else:
306+
x0 = func[: model.len_rhs]
307+
z0 = func[model.len_rhs :]
308+
Sx_0 = casadi.jacobian(x0, p_casadi_stacked)
309+
Sz_0 = casadi.jacobian(z0, p_casadi_stacked)
310+
func = casadi.vertcat(x0, Sx_0, z0, Sz_0)
261311
if use_jacobian:
262312
report(f"Calculating jacobian for {name} using CasADi")
263-
jac_casadi = casadi.jacobian(func, y_casadi)
313+
jac_casadi = casadi.jacobian(func, y_and_S)
264314
jac = casadi.Function(
265-
name, [t_casadi, y_casadi, p_casadi_stacked], [jac_casadi]
315+
name, [t_casadi, y_and_S, p_casadi_stacked], [jac_casadi]
266316
)
267317
else:
268318
jac = None
269319
func = casadi.Function(
270-
name, [t_casadi, y_casadi, p_casadi_stacked], [func]
320+
name, [t_casadi, y_and_S, p_casadi_stacked], [func]
271321
)
272322
if name == "residuals":
273323
func_call = Residuals(func, name, model)
@@ -277,6 +327,7 @@ def report(string):
277327
jac_call = SolverCallable(jac, name + "_jac", model)
278328
else:
279329
jac_call = None
330+
280331
return func, func_call, jac_call
281332

282333
# Check for heaviside functions in rhs and algebraic and add discontinuity
@@ -324,8 +375,18 @@ def report(string):
324375
)[0]
325376
init_eval = InitialConditions(initial_conditions, model)
326377

378+
if self.solve_sensitivity_equations is True:
379+
init_eval.y_dummy = np.zeros(
380+
(
381+
model.len_rhs_and_alg * (np.vstack(list(inputs.values())).size + 1),
382+
1,
383+
)
384+
)
385+
else:
386+
init_eval.y_dummy = np.zeros((model.len_rhs_and_alg, 1))
387+
327388
# Process rhs, algebraic and event expressions
328-
rhs, rhs_eval, jac_rhs = process(model.concatenated_rhs, "RHS")
389+
rhs, rhs_eval, jac_rhs = process(model.concatenated_rhs, "rhs")
329390
algebraic, algebraic_eval, jac_algebraic = process(
330391
model.concatenated_algebraic, "algebraic"
331392
)
@@ -423,7 +484,7 @@ def _set_initial_conditions(self, model, inputs, update_rhs):
423484
y0_from_inputs = model.init_eval(inputs)
424485
# Reuse old solution for algebraic equations
425486
y0_from_model = model.y0
426-
len_rhs = model.concatenated_rhs.size
487+
len_rhs = model.len_rhs
427488
# update model.y0, which is used for initialising the algebraic solver
428489
if len_rhs == 0:
429490
model.y0 = y0_from_model
@@ -861,7 +922,7 @@ def __init__(self, function, name, model):
861922

862923
def __call__(self, t, y, inputs):
863924
y = y.reshape(-1, 1)
864-
if self.name in ["RHS", "algebraic", "residuals"]:
925+
if self.name in ["rhs", "algebraic", "residuals"]:
865926
pybamm.logger.debug(
866927
"Evaluating {} for {} at t={}".format(
867928
self.name, self.model.name, t * self.timescale
@@ -874,7 +935,7 @@ def __call__(self, t, y, inputs):
874935
def function(self, t, y, inputs):
875936
if self.form == "casadi":
876937
states_eval = self._function(t, y, inputs)
877-
if self.name in ["RHS", "algebraic", "residuals", "event"]:
938+
if self.name in ["rhs", "algebraic", "residuals", "event"]:
878939
return states_eval.full()
879940
else:
880941
# keep jacobians sparse
@@ -901,7 +962,6 @@ class InitialConditions(SolverCallable):
901962

902963
def __init__(self, function, model):
903964
super().__init__(function, "initial conditions", model)
904-
self.y_dummy = np.zeros(model.concatenated_initial_conditions.shape)
905965

906966
def __call__(self, inputs):
907967
if self.form == "casadi":

pybamm/solvers/casadi_solver.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -363,9 +363,8 @@ def get_integrator(self, model, inputs, t_eval=None):
363363

364364
def _run_integrator(self, model, y0, inputs, t_eval):
365365
integrator, use_grid = self.integrators[model]
366-
len_rhs = model.concatenated_rhs.size
367-
y0_diff = y0[:len_rhs]
368-
y0_alg = y0[len_rhs:]
366+
y0_diff = y0[: model.len_rhs]
367+
y0_alg = y0[model.len_rhs :]
369368
try:
370369
# Try solving
371370
if use_grid is True:

pybamm/solvers/processed_symbolic_variable.py

-3
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,6 @@ def initialise_0D(self):
110110

111111
def initialise_1D(self):
112112
"Create a 1D variable"
113-
len_space = self.base_eval.shape[0]
114-
entries = np.empty((len_space, len(self.t_sol)))
115-
116113
# Evaluate the base_variable index-by-index
117114
for idx in range(len(self.t_sol)):
118115
t = self.t_sol[idx]

pybamm/solvers/processed_variable.py

+75
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#
22
# Processed Variable class
33
#
4+
import casadi
45
import numbers
56
import numpy as np
67
import pybamm
@@ -59,6 +60,10 @@ def __init__(self, base_variable, solution, known_evals=None, warn=True):
5960
self.known_evals = known_evals
6061
self.warn = warn
6162

63+
# Sensitivity starts off uninitialized, only set when called
64+
self._sensitivity = None
65+
self.solution_sensitivity = solution.sensitivity
66+
6267
# Set timescale
6368
self.timescale = solution.model.timescale.evaluate()
6469
self.t_pts = self.t_sol * self.timescale
@@ -553,6 +558,76 @@ def data(self):
553558
"Same as entries, but different name"
554559
return self.entries
555560

561+
@property
562+
def sensitivity(self):
563+
"""
564+
Returns a dictionary of sensitivity for each input parameter.
565+
The keys are the input parameters, and the value is a matrix of size
566+
(n_x * n_t, n_p), where n_x is the number of states, n_t is the number of time
567+
points, and n_p is the size of the input parameter
568+
"""
569+
# No sensitivity if there are no inputs
570+
if len(self.inputs) == 0:
571+
return {}
572+
# Otherwise initialise and return sensitivity
573+
if self._sensitivity is None:
574+
self.initialise_sensitivity()
575+
return self._sensitivity
576+
577+
def initialise_sensitivity(self):
578+
"Set up the sensitivity dictionary"
579+
inputs_stacked = casadi.vertcat(*[p for p in self.inputs.values()])
580+
581+
# Set up symbolic variables
582+
t_casadi = casadi.MX.sym("t")
583+
y_casadi = casadi.MX.sym("y", self.u_sol.shape[0])
584+
p_casadi = {
585+
name: casadi.MX.sym(name, value.shape[0])
586+
for name, value in self.inputs.items()
587+
}
588+
p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()])
589+
590+
# Convert variable to casadi format for differentiating
591+
var_casadi = self.base_variable.to_casadi(t_casadi, y_casadi, inputs=p_casadi)
592+
dvar_dy = casadi.jacobian(var_casadi, y_casadi)
593+
dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked)
594+
595+
# Convert to functions and evaluate index-by-index
596+
dvar_dy_func = casadi.Function(
597+
"dvar_dy", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dy]
598+
)
599+
dvar_dp_func = casadi.Function(
600+
"dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp]
601+
)
602+
for idx in range(len(self.t_sol)):
603+
t = self.t_sol[idx]
604+
u = self.u_sol[:, idx]
605+
inp = inputs_stacked[:, idx]
606+
next_dvar_dy_eval = dvar_dy_func(t, u, inp)
607+
next_dvar_dp_eval = dvar_dp_func(t, u, inp)
608+
if idx == 0:
609+
dvar_dy_eval = next_dvar_dy_eval
610+
dvar_dp_eval = next_dvar_dp_eval
611+
else:
612+
dvar_dy_eval = casadi.diagcat(dvar_dy_eval, next_dvar_dy_eval)
613+
dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval)
614+
615+
# Compute sensitivity
616+
dy_dp = self.solution_sensitivity["all"]
617+
S_var = dvar_dy_eval @ dy_dp + dvar_dp_eval
618+
619+
sensitivity = {"all": S_var}
620+
621+
# Add the individual sensitivity
622+
start = 0
623+
for name, inp in self.inputs.items():
624+
end = start + inp.shape[0]
625+
sensitivity[name] = S_var[:, start:end]
626+
start = end
627+
628+
# Save attribute
629+
self._sensitivity = sensitivity
630+
556631

557632
def eval_dimension_name(name, x, r, y, z):
558633
if name == "x":

pybamm/solvers/scipy_solver.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,24 @@ class ScipySolver(pybamm.BaseSolver):
2323
Any options to pass to the solver.
2424
Please consult `SciPy documentation <https://tinyurl.com/yafgqg9y>`_ for
2525
details.
26+
solve_sensitivity_equations : bool, optional
27+
See :class:`pybamm.BaseSolver`
2628
"""
2729

28-
def __init__(self, method="BDF", rtol=1e-6, atol=1e-6, extra_options=None):
29-
super().__init__(method, rtol, atol)
30+
def __init__(
31+
self,
32+
method="BDF",
33+
rtol=1e-6,
34+
atol=1e-6,
35+
extra_options=None,
36+
solve_sensitivity_equations=False,
37+
):
38+
super().__init__(
39+
method=method,
40+
rtol=rtol,
41+
atol=atol,
42+
solve_sensitivity_equations=solve_sensitivity_equations,
43+
)
3044
self.ode_solver = True
3145
self.extra_options = extra_options or {}
3246
self.name = "Scipy solver ({})".format(method)
@@ -52,8 +66,10 @@ def _integrate(self, model, t_eval, inputs=None):
5266
various diagnostic messages.
5367
5468
"""
69+
# Save inputs dictionary, and if necessary convert inputs to a casadi vector
70+
inputs_dict = inputs
5571
if model.convert_to_format == "casadi":
56-
inputs = casadi.vertcat(*[x for x in inputs.values()])
72+
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
5773

5874
extra_options = {**self.extra_options, "rtol": self.rtol, "atol": self.atol}
5975

@@ -107,6 +123,14 @@ def event_fn(t, y):
107123
termination = "final time"
108124
t_event = None
109125
y_event = np.array(None)
110-
return pybamm.Solution(sol.t, sol.y, t_event, y_event, termination)
126+
return pybamm.Solution(
127+
sol.t,
128+
sol.y,
129+
t_event,
130+
y_event,
131+
termination,
132+
model=model,
133+
inputs=inputs_dict,
134+
)
111135
else:
112136
raise pybamm.SolverError(sol.message)

0 commit comments

Comments
 (0)