Skip to content

Commit 41565da

Browse files
committed
#1477 python sensitivities seem ok, working on casadi
1 parent c7ddbf5 commit 41565da

File tree

3 files changed

+132
-28
lines changed

3 files changed

+132
-28
lines changed

pybamm/expression_tree/concatenations.py

+12
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,18 @@ def __str__(self):
4242
out = out[:-2] + ")"
4343
return out
4444

45+
def _diff(self, variable):
46+
""" See :meth:`pybamm.Symbol._diff()`. """
47+
children_diffs = [
48+
child.diff(variable) for child in self.cached_children
49+
]
50+
if len(children_diffs) == 1:
51+
diff = children_diffs[0]
52+
else:
53+
diff = self.__class__(children_diffs)
54+
55+
return diff
56+
4557
def get_children_domains(self, children):
4658
# combine domains from children
4759
domain = []

pybamm/solvers/base_solver.py

+81-28
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ def copy(self):
134134
new_solver.models_set_up = {}
135135
return new_solver
136136

137-
def set_up(self, model, inputs=None, t_eval=None):
137+
def set_up(self, model, inputs=None, t_eval=None,
138+
calculate_sensitivites=False):
138139
"""Unpack model, perform checks, and calculate jacobian.
139140
140141
Parameters
@@ -146,6 +147,10 @@ def set_up(self, model, inputs=None, t_eval=None):
146147
Any input parameters to pass to the model when solving
147148
t_eval : numeric type, optional
148149
The times (in seconds) at which to compute the solution
150+
calculate_sensitivites : list of str or bool
151+
If true, solver calculates sensitivities of all input parameters.
152+
If only a subset of sensitivities are required, can also pass a
153+
list of input parameter names
149154
150155
"""
151156
pybamm.logger.info("Start solver set-up")
@@ -209,14 +214,28 @@ def set_up(self, model, inputs=None, t_eval=None):
209214
)
210215
model.convert_to_format = "casadi"
211216

217+
# find all the input parameters in the model
218+
input_parameters = {}
219+
for equation in [model.concatenated_rhs,
220+
model.concatenated_algebraic,
221+
model.concatenated_initial_conditions]:
222+
input_parameters.update({
223+
symbol._id: symbol for symbol in equation.pre_order()
224+
if isinstance(symbol, pybamm.InputParameter)
225+
})
226+
227+
# from here on, calculate_sensitivites is now only a list
228+
if isinstance(calculate_sensitivites, bool):
229+
if calculate_sensitivites:
230+
calculate_sensitivites = [p for p in inputs.keys()]
231+
else:
232+
calculate_sensitivites = []
233+
212234
if model.convert_to_format != "casadi":
213235
# Create Jacobian from concatenated rhs and algebraic
214236
y = pybamm.StateVector(slice(0, model.concatenated_initial_conditions.size))
215237
# set up Jacobian object, for re-use of dict
216238
jacobian = pybamm.Jacobian()
217-
jacobian_parameters = {
218-
p: pybamm.Jacobian() for p in inputs.keys()
219-
}
220239

221240
else:
222241
# Convert model attributes to casadi
@@ -244,8 +263,11 @@ def report(string):
244263
if model.convert_to_format == "jax":
245264
report(f"Converting {name} to jax")
246265
func = pybamm.EvaluatorJax(func)
247-
if self.sensitivity:
248-
report(f"Calculating sensitivities for {name} using jax")
266+
if calculate_sensitivites:
267+
report((
268+
f"Calculating sensitivities for {name} with respect "
269+
f"to parameters {calculate_sensitivites} using jax"
270+
))
249271
jacp_dict = func.get_sensitivities()
250272
else:
251273
jacp_dict = None
@@ -261,19 +283,24 @@ def report(string):
261283
elif model.convert_to_format != "casadi":
262284
# Process with pybamm functions, optionally converting
263285
# to python evaluator
264-
if self.sensitivity:
265-
report(f"Calculating sensitivities for {name}")
286+
print('calculate_sensitivites = ', calculate_sensitivites)
287+
if calculate_sensitivites:
288+
report((
289+
f"Calculating sensitivities for {name} with respect "
290+
f"to parameters {calculate_sensitivites}"
291+
))
292+
print(type(func))
266293
jacp_dict = {
267-
p: jwrtp.jac(func, pybamm.InputParameter(p))
268-
for jwrtp, p in
269-
zip(jacobian_parameters, inputs.keys())
294+
p: func.diff(pybamm.InputParameter(p))
295+
for p in calculate_sensitivites
270296
}
271297
if model.convert_to_format == "python":
272298
report(f"Converting sensitivities for {name} to python")
273299
jacp_dict = {
274300
p: pybamm.EvaluatorPython(jacp)
275301
for p, jacp in jacp_dict.items()
276302
}
303+
jacp_dict = {k: v.evaluate for k, v in jacp_dict.items()}
277304
else:
278305
jacp_dict = None
279306

@@ -306,12 +333,18 @@ def report(string):
306333
else:
307334
jac = None
308335

309-
if self.sensitivity:
310-
report(f"Calculating sensitivities for {name} using CasADi")
311-
jacp_dict = {
312-
name: casadi.jacobian(func, p)
313-
for name, p in p_casadi.items()
314-
}
336+
if calculate_sensitivites:
337+
report((
338+
f"Calculating sensitivities for {name} with respect "
339+
f"to parameters {calculate_sensitivites} using CasADi"
340+
))
341+
jacp_dict = {}
342+
for pname in calculate_sensitivites:
343+
p_diff = casadi.jacobian(func, p_casadi[pname])
344+
jacp_dict[pname] = casadi.Function(
345+
name, [t_casadi, y_casadi, p_casadi_stacked],
346+
[p_diff]
347+
)
315348
else:
316349
jacp_dict = None
317350

@@ -326,7 +359,12 @@ def report(string):
326359
jac_call = SolverCallable(jac, name + "_jac", model)
327360
else:
328361
jac_call = None
329-
return func, func_call, jac_call
362+
if jacp_dict is not None:
363+
jacp_call = {
364+
k: SolverCallable(v, name + "_sensitivity_wrt_" + k, model)
365+
for k, v in jacp_dict.items()
366+
}
367+
return func, func_call, jac_call, jacp_call
330368

331369
# Check for heaviside and modulo functions in rhs and algebraic and add
332370
# discontinuity events if these exist.
@@ -400,8 +438,8 @@ def report(string):
400438
init_eval = InitialConditions(initial_conditions, model)
401439

402440
# Process rhs, algebraic and event expressions
403-
rhs, rhs_eval, jac_rhs = process(model.concatenated_rhs, "RHS")
404-
algebraic, algebraic_eval, jac_algebraic = process(
441+
rhs, rhs_eval, jac_rhs, jacp_rhs = process(model.concatenated_rhs, "RHS")
442+
algebraic, algebraic_eval, jac_algebraic, jacp_algebraic = process(
405443
model.concatenated_algebraic, "algebraic"
406444
)
407445

@@ -486,19 +524,23 @@ def report(string):
486524
# No rhs equations: residuals is algebraic only
487525
model.residuals_eval = Residuals(algebraic, "residuals", model)
488526
model.jacobian_eval = jac_algebraic
527+
model.sensitivities_eval = jacp_algebraic
489528
elif len(model.algebraic) == 0:
490529
# No algebraic equations: residuals is rhs only
491530
model.residuals_eval = Residuals(rhs, "residuals", model)
492531
model.jacobian_eval = jac_rhs
532+
model.sensitivities_eval = jacp_rhs
493533
# Calculate consistent initial conditions for the algebraic equations
494534
else:
495535
all_states = pybamm.NumpyConcatenation(
496536
model.concatenated_rhs, model.concatenated_algebraic
497537
)
498538
# Process again, uses caching so should be quick
499-
residuals_eval, jacobian_eval = process(all_states, "residuals")[1:]
539+
residuals_eval, jacobian_eval, jacobian_wrtp_eval = \
540+
process(all_states, "residuals")[1:]
500541
model.residuals_eval = residuals_eval
501542
model.jacobian_eval = jacobian_eval
543+
model.sensitivities_eval = jacobian_wrtp_eval
502544

503545
pybamm.logger.info("Finish solver set-up")
504546

@@ -589,6 +631,7 @@ def solve(
589631
inputs=None,
590632
initial_conditions=None,
591633
nproc=None,
634+
calculate_sensitivities=False
592635
):
593636
"""
594637
Execute the solver setup and calculate the solution of the model at
@@ -614,6 +657,10 @@ def solve(
614657
nproc : int, optional
615658
Number of processes to use when solving for more than one set of input
616659
parameters. Defaults to value returned by "os.cpu_count()".
660+
calculate_sensitivites : list of str or bool
661+
If true, solver calculates sensitivities of all input parameters.
662+
If only a subset of sensitivities are required, can also pass a
663+
list of input parameter names
617664
618665
Returns
619666
-------
@@ -690,7 +737,8 @@ def solve(
690737
# not depend on input parameters. Thefore only `ext_and_inputs[0]`
691738
# is passed to `set_up`.
692739
# See https://github.com/pybamm-team/PyBaMM/pull/1261
693-
self.set_up(model, ext_and_inputs_list[0], t_eval)
740+
self.set_up(model, ext_and_inputs_list[0], t_eval,
741+
calculate_sensitivities)
694742
self.models_set_up.update(
695743
{model: {"initial conditions": model.concatenated_initial_conditions}}
696744
)
@@ -701,7 +749,8 @@ def solve(
701749
# If the new initial conditions are different, set up again
702750
# Doing the whole setup again might be slow, but no need to prematurely
703751
# optimize this
704-
self.set_up(model, ext_and_inputs_list[0], t_eval)
752+
self.set_up(model, ext_and_inputs_list[0], t_eval,
753+
calculate_sensitivities)
705754
self.models_set_up[model][
706755
"initial conditions"
707756
] = model.concatenated_initial_conditions
@@ -951,6 +1000,9 @@ def step(
9511000
save : bool
9521001
Turn on to store the solution of all previous timesteps
9531002
1003+
1004+
1005+
9541006
Raises
9551007
------
9561008
:class:`pybamm.ModelError`
@@ -1241,12 +1293,13 @@ def __init__(self, function, name, model):
12411293
self.timescale = self.model.timescale_eval
12421294

12431295
def __call__(self, t, y, inputs):
1244-
if self.name in ["RHS", "algebraic", "residuals"]:
1245-
pybamm.logger.debug(
1246-
"Evaluating {} for {} at t={}".format(
1247-
self.name, self.model.name, t * self.timescale
1248-
)
1296+
pybamm.logger.debug(
1297+
"Evaluating {} for {} at t={}".format(
1298+
self.name, self.model.name, t * self.timescale
12491299
)
1300+
)
1301+
if self.name in ["RHS", "algebraic", "residuals"]:
1302+
12501303
return self.function(t, y, inputs).flatten()
12511304
else:
12521305
return self.function(t, y, inputs)

tests/unit/test_solvers/test_base_solver.py

+39
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,45 @@ def test_extrapolation_warnings(self):
322322
with self.assertWarns(pybamm.SolverWarning):
323323
solver.solve(model, t_eval=[0, 1])
324324

325+
def test_sensitivities(self):
326+
pybamm.set_logging_level('DEBUG')
327+
328+
def exact_diff_a(v, a, b):
329+
return v**2 + 2 * a
330+
331+
def exact_diff_b(v, a, b):
332+
return v
333+
334+
for f in ['', 'python', 'casadi']:
335+
model = pybamm.BaseModel()
336+
v = pybamm.Variable("v")
337+
a = pybamm.InputParameter("a")
338+
b = pybamm.InputParameter("b")
339+
model.rhs = {v: a * v**2 + b * v + a**2}
340+
model.initial_conditions = {v: 1}
341+
model.convert_to_format = f
342+
solver = pybamm.ScipySolver()
343+
solver.set_up(model, calculate_sensitivites=True,
344+
inputs={'a': 0, 'b': 0})
345+
for v_value in [0.1, -0.2, 1.5, 8.4]:
346+
for a_value in [0.12, 1.5]:
347+
for b_value in [0.82, 1.9]:
348+
y = np.array([v_value])
349+
t = 0
350+
inputs = {'a': a_value, 'b': b_value}
351+
352+
self.assertAlmostEqual(
353+
model.sensitivities_eval['a'](
354+
t=0, y=y, inputs=inputs
355+
),
356+
exact_diff_a(v_value, a_value, b_value)
357+
)
358+
self.assertAlmostEqual(
359+
model.sensitivities_eval['b'](
360+
t=0, y=y, inputs=inputs
361+
),
362+
exact_diff_b(v_value, a_value, b_value)
363+
)
325364

326365
if __name__ == "__main__":
327366
print("Add -v for more debug output")

0 commit comments

Comments
 (0)