Skip to content

Commit c7ddbf5

Browse files
committed
#1477 draft out a test for idaklu and changes to base solver for sensitivities
1 parent f7bb791 commit c7ddbf5

File tree

4 files changed

+156
-21
lines changed

4 files changed

+156
-21
lines changed

pybamm/expression_tree/operations/evaluate.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def __init__(self, symbol):
557557
constants[symbol_id] = jax.device_put(constants[symbol_id])
558558

559559
# get a list of constant arguments to input to the function
560-
arg_list = [
560+
self._arg_list = [
561561
id_to_python_variable(symbol_id, True) for symbol_id in constants.keys()
562562
]
563563

@@ -578,8 +578,8 @@ def __init__(self, symbol):
578578

579579
# add function def to first line
580580
args = "t=None, y=None, y_dot=None, inputs=None, known_evals=None"
581-
if arg_list:
582-
args = ",".join(arg_list) + ", " + args
581+
if self._arg_list:
582+
args = ",".join(self._arg_list) + ", " + args
583583
python_str = "def evaluate_jax({}):\n".format(args) + python_str
584584

585585
# calculate the final variable that will output the result of calling `evaluate`
@@ -604,17 +604,33 @@ def __init__(self, symbol):
604604
compiled_function = compile(python_str, result_var, "exec")
605605
exec(compiled_function)
606606

607-
n = len(arg_list)
608-
static_argnums = tuple(static_argnums)
609-
self._jit_evaluate = jax.jit(self._evaluate_jax, static_argnums=static_argnums)
607+
self._static_argnums = tuple(static_argnums)
608+
self._jit_evaluate = jax.jit(self._evaluate_jax,
609+
static_argnums=self._static_argnums)
610610

611-
# store a jit version of evaluate_jax's jacobian
611+
def get_jacobian(self):
612+
n = len(self._arg_list)
613+
614+
# forward mode autodiff wrt y, which is argument 1 after arg_list
612615
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=1 + n)
613-
self._jac_evaluate = jax.jit(jacobian_evaluate, static_argnums=static_argnums)
614616

615-
def get_jacobian(self):
617+
self._jac_evaluate = jax.jit(jacobian_evaluate,
618+
static_argnums=self._static_argnums)
619+
620+
return EvaluatorJaxJacobian(self._jac_evaluate, self._constants)
621+
622+
def get_sensitivities(self):
623+
n = len(self._arg_list)
624+
625+
# forward mode autodiff wrt inputs, which is argument 3 after arg_list
626+
jacobian_evaluate = jax.jacfwd(self._evaluate_jax, argnums=3 + n)
627+
628+
self._sens_evaluate = jax.jit(jacobian_evaluate,
629+
static_argnums=self._static_argnums)
630+
616631
return EvaluatorJaxJacobian(self._jac_evaluate, self._constants)
617632

633+
618634
def debug(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
619635
# generated code assumes y is a column vector
620636
if y is not None and y.ndim == 1:

pybamm/solvers/base_solver.py

+60-11
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ class BaseSolver(object):
3434
The tolerance for the initial-condition solver (default is 1e-6).
3535
extrap_tol : float, optional
3636
The tolerance to assert whether extrapolation occurs or not. Default is 0.
37+
sensitivity : str, optional
38+
Whether (and how) to calculate sensitivities when solving. Options are:
39+
- "explicit forward": explicitly formulate the sensitivity equations. \
40+
The formulation is as per "Park, S., Kato, D., Gima, Z., \
41+
Klein, R., & Moura, S. (2018). Optimal experimental design for parameterization\
42+
of an electrochemical lithium-ion battery model. Journal of The Electrochemical\
43+
Society, 165(7), A1309.". See #1100 for details \
44+
- see individual solvers for other options
45+
3746
"""
3847

3948
def __init__(
@@ -45,6 +54,7 @@ def __init__(
4554
root_tol=1e-6,
4655
extrap_tol=0,
4756
max_steps="deprecated",
57+
sensitivity=None
4858
):
4959
self._method = method
5060
self._rtol = rtol
@@ -63,6 +73,7 @@ def __init__(
6373
self.name = "Base solver"
6474
self.ode_solver = False
6575
self.algebraic_solver = False
76+
self.sensitivity = sensitivity
6677

6778
@property
6879
def method(self):
@@ -203,6 +214,10 @@ def set_up(self, model, inputs=None, t_eval=None):
203214
y = pybamm.StateVector(slice(0, model.concatenated_initial_conditions.size))
204215
# set up Jacobian object, for re-use of dict
205216
jacobian = pybamm.Jacobian()
217+
jacobian_parameters = {
218+
p: pybamm.Jacobian() for p in inputs.keys()
219+
}
220+
206221
else:
207222
# Convert model attributes to casadi
208223
t_casadi = casadi.MX.sym("t")
@@ -225,32 +240,56 @@ def report(string):
225240

226241
if use_jacobian is None:
227242
use_jacobian = model.use_jacobian
228-
if model.convert_to_format != "casadi":
229-
# Process with pybamm functions
230243

231-
if model.convert_to_format == "jax":
232-
report(f"Converting {name} to jax")
233-
jax_func = pybamm.EvaluatorJax(func)
244+
if model.convert_to_format == "jax":
245+
report(f"Converting {name} to jax")
246+
func = pybamm.EvaluatorJax(func)
247+
if self.sensitivity:
248+
report(f"Calculating sensitivities for {name} using jax")
249+
jacp_dict = func.get_sensitivities()
250+
else:
251+
jacp_dict = None
252+
if use_jacobian:
253+
report(f"Calculating jacobian for {name} using jax")
254+
jac = func.get_jacobian()
255+
jac = jac.evaluate
256+
else:
257+
jac = None
258+
259+
func = func.evaluate
260+
261+
elif model.convert_to_format != "casadi":
262+
# Process with pybamm functions, optionally converting
263+
# to python evaluator
264+
if self.sensitivity:
265+
report(f"Calculating sensitivities for {name}")
266+
jacp_dict = {
267+
p: jwrtp.jac(func, pybamm.InputParameter(p))
268+
for jwrtp, p in
269+
zip(jacobian_parameters, inputs.keys())
270+
}
271+
if model.convert_to_format == "python":
272+
report(f"Converting sensitivities for {name} to python")
273+
jacp_dict = {
274+
p: pybamm.EvaluatorPython(jacp)
275+
for p, jacp in jacp_dict.items()
276+
}
277+
else:
278+
jacp_dict = None
234279

235280
if use_jacobian:
236281
report(f"Calculating jacobian for {name}")
237282
jac = jacobian.jac(func, y)
238283
if model.convert_to_format == "python":
239284
report(f"Converting jacobian for {name} to python")
240285
jac = pybamm.EvaluatorPython(jac)
241-
elif model.convert_to_format == "jax":
242-
report(f"Converting jacobian for {name} to jax")
243-
jac = jax_func.get_jacobian()
244286
jac = jac.evaluate
245287
else:
246288
jac = None
247289

248290
if model.convert_to_format == "python":
249291
report(f"Converting {name} to python")
250292
func = pybamm.EvaluatorPython(func)
251-
if model.convert_to_format == "jax":
252-
report(f"Converting {name} to jax")
253-
func = jax_func
254293

255294
func = func.evaluate
256295

@@ -266,6 +305,16 @@ def report(string):
266305
)
267306
else:
268307
jac = None
308+
309+
if self.sensitivity:
310+
report(f"Calculating sensitivities for {name} using CasADi")
311+
jacp_dict = {
312+
name: casadi.jacobian(func, p)
313+
for name, p in p_casadi.items()
314+
}
315+
else:
316+
jacp_dict = None
317+
269318
func = casadi.Function(
270319
name, [t_casadi, y_casadi, p_casadi_stacked], [func]
271320
)

pybamm/solvers/idaklu_solver.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,15 @@ class IDAKLUSolver(pybamm.BaseSolver):
3838
The tolerance for the initial-condition solver (default is 1e-6).
3939
extrap_tol : float, optional
4040
The tolerance to assert whether extrapolation occurs or not (default is 0).
41+
sensitivity : str, optional
42+
Whether (and how) to calculate sensitivities when solving. Options are:
43+
- "explicit forward": explicitly formulate the sensitivity equations. \
44+
The formulation is as per "Park, S., Kato, D., Gima, Z., \
45+
Klein, R., & Moura, S. (2018). Optimal experimental design for parameterization\
46+
of an electrochemical lithium-ion battery model. Journal of The Electrochemical\
47+
Society, 165(7), A1309.". See #1100 for details \
48+
- "idas": use Sundials IDAS to compute forward sensitivities
49+
4150
"""
4251

4352
def __init__(
@@ -48,13 +57,21 @@ def __init__(
4857
root_tol=1e-6,
4958
extrap_tol=0,
5059
max_steps="deprecated",
60+
sensitivity="idas"
5161
):
5262

5363
if idaklu_spec is None:
5464
raise ImportError("KLU is not installed")
5565

5666
super().__init__(
57-
"ida", rtol, atol, root_method, root_tol, extrap_tol, max_steps
67+
"ida",
68+
rtol,
69+
atol,
70+
root_method,
71+
root_tol,
72+
extrap_tol,
73+
max_steps,
74+
sensitivity=sensitivity,
5875
)
5976
self.name = "IDA KLU solver"
6077

tests/unit/test_solvers/test_idaklu_solver.py

+53
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,59 @@ def test_ida_roberts_klu(self):
4646
true_solution = 0.1 * solution.t
4747
np.testing.assert_array_almost_equal(solution.y[0, :], true_solution)
4848

49+
def test_ida_roberts_klu_sensitivities(self):
50+
# this test implements a python version of the ida Roberts
51+
# example provided in sundials
52+
# see sundials ida examples pdf
53+
for form in ["python", "casadi"]:
54+
model = pybamm.BaseModel()
55+
model.convert_to_format = form
56+
u = pybamm.Variable("u")
57+
v = pybamm.Variable("v")
58+
a = pybamm.InputParameter("a")
59+
model.rhs = {u: a * v}
60+
model.algebraic = {v: 1 - v}
61+
model.initial_conditions = {u: 0, v: 1}
62+
model.events = [pybamm.Event("1", u - 0.2), pybamm.Event("2", v)]
63+
64+
disc = pybamm.Discretisation()
65+
disc.process_model(model)
66+
67+
solver = pybamm.IDAKLUSolver(root_method="lm")
68+
69+
t_eval = np.linspace(0, 3, 100)
70+
a_value = 0.1
71+
sol = solver.solve(model, t_eval, inputs={"a": a_value})
72+
73+
# test that final time is time of event
74+
# y = 0.1 t + y0 so y=0.2 when t=2
75+
np.testing.assert_array_almost_equal(sol.t[-1], 2.0)
76+
77+
# test that final value is the event value
78+
np.testing.assert_array_almost_equal(sol.y[0, -1], 0.2)
79+
80+
# test that y[1] remains constant
81+
np.testing.assert_array_almost_equal(
82+
sol.y[1, :], np.ones(sol.t.shape)
83+
)
84+
85+
# test that y[0] = to true solution
86+
true_solution = 0.1 * sol.t
87+
np.testing.assert_array_almost_equal(sol.y[0, :], true_solution)
88+
89+
# evaluate the sensitivities using idas
90+
dyda_ida = sol.sensitivities["a"]
91+
92+
# evaluate the sensitivities using finite difference
93+
h = 1e-6
94+
sol_plus = solver.solve(model, t_eval, inputs={"a": a_value + 0.5 * h})
95+
sol_neg = solver.solve(model, t_eval, inputs={"a": a_value - 0.5 * h})
96+
dyda_fd = (sol_plus.y - sol_neg.y) / h
97+
98+
np.testing.assert_array_almost_equal(
99+
dyda_ida, dyda_fd
100+
)
101+
49102
def test_set_atol(self):
50103
model = pybamm.lithium_ion.DFN()
51104
geometry = model.default_geometry

0 commit comments

Comments
 (0)