Skip to content

Commit 840f073

Browse files
committed
#1477 going to take out sensitivity=casadi option
1 parent 5214994 commit 840f073

File tree

6 files changed

+114
-118
lines changed

6 files changed

+114
-118
lines changed

pybamm/solvers/base_solver.py

+46-43
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,7 @@ class BaseSolver(object):
4242
the solution instance returned. At the moment this is only implemented for the
4343
IDAKLU solver.\
4444
- "explicit forward": explicitly formulate the sensitivity equations for
45-
the chosen input parameters. The formulation is as per
46-
"Park, S., Kato, D., Gima, Z., Klein, R., & Moura, S. (2018).\
47-
Optimal experimental design for parameterization of an electrochemical
48-
lithium-ion battery model. Journal of The Electrochemical\
49-
Society, 165(7), A1309.". See #1100 for details. At the moment this is only
45+
the chosen input parameters. . At the moment this is only
5046
implemented using convert_to_format = 'casadi'. \
5147
- see individual solvers for other options
5248
"""
@@ -60,7 +56,6 @@ def __init__(
6056
root_tol=1e-6,
6157
extrap_tol=0,
6258
max_steps="deprecated",
63-
sensitivity=None,
6459
):
6560
self._method = method
6661
self._rtol = rtol
@@ -79,7 +74,6 @@ def __init__(
7974
self.name = "Base solver"
8075
self.ode_solver = False
8176
self.algebraic_solver = False
82-
self.sensitivity = sensitivity
8377

8478
@property
8579
def method(self):
@@ -140,8 +134,6 @@ def copy(self):
140134
new_solver.models_set_up = {}
141135
return new_solver
142136

143-
144-
145137
def set_up(self, model, inputs=None, t_eval=None,
146138
calculate_sensitivites=False):
147139
"""Unpack model, perform checks, and calculate jacobian.
@@ -238,21 +230,17 @@ def set_up(self, model, inputs=None, t_eval=None,
238230
calculate_sensitivites = [p for p in inputs.keys()]
239231
else:
240232
calculate_sensitivites = []
233+
234+
calculate_sensitivites_explicit = False
235+
if calculate_sensitivites and not isinstance(self, pybamm.IDAKLUSolver):
236+
calculate_sensitivites_explicit = True
237+
241238
# save sensitivity parameters so we can identify them later on
242239
# (FYI: this is used in the Solution class)
243240
model.calculate_sensitivities = calculate_sensitivites
244-
model.len_rhs_sens = model.len_rhs * len(calculate_sensitivites)
245-
model.len_alg_sens = model.len_alg * len(calculate_sensitivites)
246-
247-
# Only allow solving explicit sensitivity equations with the casadi format for now
248-
if (
249-
self.sensitivity == "explicit forward"
250-
and model.convert_to_format != "casadi"
251-
):
252-
raise NotImplementedError(
253-
"model should be converted to casadi format in order to solve "
254-
"explicit sensitivity equations"
255-
)
241+
if calculate_sensitivites_explicit:
242+
model.len_rhs_sens = model.len_rhs * len(calculate_sensitivites)
243+
model.len_alg_sens = model.len_alg * len(calculate_sensitivites)
256244

257245
if model.convert_to_format != "casadi":
258246
# Create Jacobian from concatenated rhs and algebraic
@@ -275,7 +263,7 @@ def set_up(self, model, inputs=None, t_eval=None,
275263
p_casadi[name] = casadi.MX.sym(name, value.shape[0])
276264
p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()])
277265
# sensitivity vectors
278-
if self.sensitivity == "explicit forward":
266+
if calculate_sensitivites_explicit:
279267
pS_casadi_stacked = casadi.vertcat(
280268
*[p_casadi[name] for name in calculate_sensitivites]
281269
)
@@ -297,15 +285,19 @@ def report(string):
297285
if model.convert_to_format == "jax":
298286
report(f"Converting {name} to jax")
299287
func = pybamm.EvaluatorJax(func)
300-
if calculate_sensitivites:
288+
jacp = None
289+
if calculate_sensitivites_explicit:
290+
raise NotImplementedError(
291+
"sensitivities using convert_to_format = 'jax' "
292+
"only implemented for IDAKLUSolver"
293+
)
294+
elif calculate_sensitivites:
301295
report((
302296
f"Calculating sensitivities for {name} with respect "
303297
f"to parameters {calculate_sensitivites} using jax"
304298
))
305299
jacp = func.get_sensitivities()
306300
jacp = jacp.evaluate
307-
else:
308-
jacp = None
309301
if use_jacobian:
310302
report(f"Calculating jacobian for {name} using jax")
311303
jac = func.get_jacobian()
@@ -319,6 +311,10 @@ def report(string):
319311
# Process with pybamm functions, optionally converting
320312
# to python evaluator
321313
if calculate_sensitivites:
314+
raise NotImplementedError(
315+
"sensitivities only implemented with "
316+
"convert_to_format = 'casadi' or convert_to_format = 'jax'"
317+
)
322318
report((
323319
f"Calculating sensitivities for {name} with respect "
324320
f"to parameters {calculate_sensitivites}"
@@ -362,9 +358,16 @@ def jacp(*args, **kwargs):
362358
report(f"Converting {name} to CasADi")
363359
func = func.to_casadi(t_casadi, y_casadi, inputs=p_casadi)
364360
# Add sensitivity vectors to the rhs and algebraic equations
365-
if self.sensitivity == "explicit forward":
361+
jacp = None
362+
if calculate_sensitivites_explicit:
363+
# The formulation is as per Park, S., Kato, D., Gima, Z., Klein, R.,
364+
# & Moura, S. (2018). Optimal experimental design for
365+
# parameterization of an electrochemical lithium-ion battery model.
366+
# Journal of The Electrochemical Society, 165(7), A1309.". See #1100
367+
# for details
366368
if name == "rhs" and model.len_rhs > 0:
367-
report("Creating sensitivity equations for rhs using CasADi")
369+
report(
370+
"Creating explicit forward sensitivity equations for rhs using CasADi")
368371
df_dx = casadi.jacobian(func, y_diff)
369372
df_dp = casadi.jacobian(func, pS_casadi_stacked)
370373
S_x_mat = S_x.reshape(
@@ -383,7 +386,7 @@ def jacp(*args, **kwargs):
383386
func = casadi.vertcat(func, S_rhs)
384387
if name == "algebraic" and model.len_alg > 0:
385388
report(
386-
"Creating sensitivity equations for algebraic using CasADi"
389+
"Creating explicit forward sensitivity equations for algebraic using CasADi"
387390
)
388391
dg_dz = casadi.jacobian(func, y_alg)
389392
dg_dp = casadi.jacobian(func, pS_casadi_stacked)
@@ -401,7 +404,12 @@ def jacp(*args, **kwargs):
401404
(-1, 1)
402405
)
403406
func = casadi.vertcat(func, S_alg)
404-
elif name == "initial_conditions":
407+
if name == "residuals":
408+
raise NotImplementedError(
409+
"explicit forward equations not implimented for residuals"
410+
)
411+
412+
if name == "initial_conditions":
405413
if model.len_rhs == 0 or model.len_alg == 0:
406414
S_0 = casadi.jacobian(func, pS_casadi_stacked).reshape(
407415
(-1, 1)
@@ -417,16 +425,7 @@ def jacp(*args, **kwargs):
417425
(-1, 1)
418426
)
419427
func = casadi.vertcat(x0, Sx_0, z0, Sz_0)
420-
if use_jacobian:
421-
report(f"Calculating jacobian for {name} using CasADi")
422-
jac_casadi = casadi.jacobian(func, y_and_S)
423-
jac = casadi.Function(
424-
name, [t_casadi, y_and_S, p_casadi_stacked], [jac_casadi]
425-
)
426-
else:
427-
jac = None
428-
429-
if calculate_sensitivites and self.sensitivity != "explicit forward":
428+
elif calculate_sensitivites:
430429
report((
431430
f"Calculating sensitivities for {name} with respect "
432431
f"to parameters {calculate_sensitivites} using CasADi"
@@ -444,8 +443,14 @@ def jacp(*args, **kwargs):
444443
return {k: v(*args, **kwargs)
445444
for k, v in jacp_dict.items()}
446445

446+
if use_jacobian:
447+
report(f"Calculating jacobian for {name} using CasADi")
448+
jac_casadi = casadi.jacobian(func, y_and_S)
449+
jac = casadi.Function(
450+
name, [t_casadi, y_and_S, p_casadi_stacked], [jac_casadi]
451+
)
447452
else:
448-
jacp = None
453+
jac = None
449454

450455
func = casadi.Function(
451456
name, [t_casadi, y_and_S, p_casadi_stacked], [func]
@@ -538,7 +543,7 @@ def jacp(*args, **kwargs):
538543
)[0]
539544
init_eval = InitialConditions(initial_conditions, model)
540545

541-
if self.sensitivity == "explicit forward":
546+
if calculate_sensitivites_explicit:
542547
y0_total_size = (
543548
model.len_rhs + model.len_rhs_sens
544549
+ model.len_alg + model.len_alg_sens
@@ -555,7 +560,6 @@ def jacp(*args, **kwargs):
555560

556561
# Calculate initial conditions
557562
model.y0 = init_eval(inputs)
558-
print('YYYYY', model.y0)
559563

560564
casadi_terminate_events = []
561565
terminate_events_eval = []
@@ -726,7 +730,6 @@ def _set_initial_conditions(self, model, inputs, update_rhs):
726730
model.y0 = casadi.Function("y0", [symbolic_inputs], [y0])
727731
else:
728732
model.y0 = y0
729-
print('ASDF', model.y0)
730733

731734
def calculate_consistent_state(self, model, time=0, inputs=None):
732735
"""

pybamm/solvers/casadi_solver.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,6 @@ class CasadiSolver(pybamm.BaseSolver):
6262
Any options to pass to the CasADi integrator when calling the integrator.
6363
Please consult `CasADi documentation <https://tinyurl.com/y5rk76os>`_ for
6464
details.
65-
sensitivity : str, optional
66-
Whether (and how) to calculate sensitivities when solving. Options are:
67-
68-
- None: no sensitivities
69-
- "explicit forward": explicitly formulate the sensitivity equations. \
70-
See :class:`pybamm.BaseSolver`
71-
- "casadi": use casadi to differentiate through the integrator
7265
"""
7366

7467
def __init__(
@@ -83,7 +76,6 @@ def __init__(
8376
extrap_tol=0,
8477
extra_options_setup=None,
8578
extra_options_call=None,
86-
sensitivity=None,
8779
):
8880
super().__init__(
8981
"problem dependent",
@@ -92,7 +84,6 @@ def __init__(
9284
root_method,
9385
root_tol,
9486
extrap_tol,
95-
sensitivity=sensitivity,
9687
)
9788
if mode in ["safe", "fast", "fast with events", "safe without grid"]:
9889
self.mode = mode
@@ -138,6 +129,9 @@ def _integrate(self, model, t_eval, inputs_dict=None):
138129

139130
# Record whether there are any symbolic inputs
140131
inputs_dict = inputs_dict or {}
132+
has_symbolic_inputs = any(
133+
isinstance(v, casadi.MX) for v in inputs_dict.values()
134+
)
141135

142136
# convert inputs to casadi format
143137
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
@@ -176,7 +170,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
176170
else:
177171
# Create integrator without grid, which will be called repeatedly
178172
# This is necessary for casadi to compute sensitivities
179-
self.create_integrator(model, inputs_dict)
173+
self.create_integrator(model, inputs)
180174
solution = self._run_integrator(
181175
model, model.y0, inputs_dict, inputs, t_eval
182176
)
@@ -216,7 +210,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
216210
# in "safe without grid" mode,
217211
# create integrator once, without grid,
218212
# to avoid having to create several times
219-
self.create_integrator(model, inputs_dict)
213+
self.create_integrator(model, inputs)
220214
# Initialize solution
221215
solution = pybamm.Solution(
222216
np.array([t]), y0, model, inputs_dict,
@@ -258,7 +252,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
258252

259253
if self.mode == "safe":
260254
# update integrator with the grid
261-
self.create_integrator(model, inputs_dict, t_window)
255+
self.create_integrator(model, inputs, t_window)
262256
# Try to solve with the current global step, if it fails then
263257
# halve the step size and try again.
264258
try:

pybamm/solvers/idaklu_solver.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def __init__(
6262
if idaklu_spec is None:
6363
raise ImportError("KLU is not installed")
6464

65+
if sensitivity == "explicit forward":
66+
raise NotImplementedError(
67+
"Cannot use explicit forward equations with IDAKLUSolver"
68+
)
69+
6570
super().__init__(
6671
"ida",
6772
rtol,
@@ -188,12 +193,9 @@ def _integrate(self, model, t_eval, inputs_dict=None):
188193
atol = self._atol
189194

190195
y0 = model.y0
191-
print('idaklu, y0', y0)
192196
if isinstance(y0, casadi.DM):
193197
y0 = y0.full().flatten()
194198

195-
print('idaklu, y0', y0)
196-
197199
rtol = self._rtol
198200
atol = self._check_atol_type(atol, y0.size)
199201

pybamm/solvers/processed_variable.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
4646
self.auxiliary_domains = base_variables[0].auxiliary_domains
4747
self.warn = warn
4848

49-
self.symbolic_inputs = solution._symbolic_inputs
49+
self.symbolic_inputs = solution.has_symbolic_inputs
5050

5151
self.u_sol = solution.y
5252
self.y_sym = solution._y_sym

tests/integration/test_solvers/test_idaklu.py

+42
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,47 @@ def test_on_spme(self):
1919
solution = pybamm.IDAKLUSolver().solve(model, t_eval)
2020
np.testing.assert_array_less(1, solution.t.size)
2121

22+
def test_on_spme_sensitivities(self):
23+
param_name = "Negative electrode conductivity [S.m-1]"
24+
neg_electrode_cond = 100.0
25+
model = pybamm.lithium_ion.SPMe()
26+
geometry = model.default_geometry
27+
param = model.default_parameter_values
28+
param.update({param_name: "[input]"})
29+
inputs = {param_name: neg_electrode_cond}
30+
param.process_model(model)
31+
param.process_geometry(geometry)
32+
mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts)
33+
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
34+
disc.process_model(model)
35+
t_eval = np.linspace(0, 3600, 100)
36+
solver = pybamm.IDAKLUSolver()
37+
solution = solver.solve(
38+
model, t_eval,
39+
inputs=inputs,
40+
calculate_sensitivities=True,
41+
)
42+
np.testing.assert_array_less(1, solution.t.size)
43+
44+
# evaluate the sensitivities using idas
45+
dyda_ida = solution.sensitivities[param_name]
46+
47+
# evaluate the sensitivities using finite difference
48+
h = 1e-6
49+
sol_plus = solver.solve(
50+
model, t_eval,
51+
inputs={param_name: neg_electrode_cond + 0.5 * h}
52+
)
53+
sol_neg = solver.solve(
54+
model, t_eval,
55+
inputs={param_name: neg_electrode_cond - 0.5 * h}
56+
)
57+
dyda_fd = (sol_plus.y - sol_neg.y) / h
58+
59+
np.testing.assert_array_almost_equal(
60+
dyda_ida, dyda_fd
61+
)
62+
2263
def test_set_tol_by_variable(self):
2364
model = pybamm.lithium_ion.SPMe()
2465
geometry = model.default_geometry
@@ -68,6 +109,7 @@ def test_changing_grid(self):
68109
if __name__ == "__main__":
69110
print("Add -v for more debug output")
70111

112+
pybamm.set_logging_level('INFO')
71113
if "-v" in sys.argv:
72114
debug = True
73115
pybamm.settings.debug_mode = True

0 commit comments

Comments
 (0)