@@ -34,6 +34,15 @@ class BaseSolver(object):
34
34
The tolerance for the initial-condition solver (default is 1e-6).
35
35
extrap_tol : float, optional
36
36
The tolerance to assert whether extrapolation occurs or not. Default is 0.
37
+ sensitivity : str, optional
38
+ Whether (and how) to calculate sensitivities when solving. Options are:
39
+ - "explicit forward": explicitly formulate the sensitivity equations. \
40
+ The formulation is as per "Park, S., Kato, D., Gima, Z., \
41
+ Klein, R., & Moura, S. (2018). Optimal experimental design for parameterization\
42
+ of an electrochemical lithium-ion battery model. Journal of The Electrochemical\
43
+ Society, 165(7), A1309.". See #1100 for details \
44
+ - see individual solvers for other options
45
+
37
46
"""
38
47
39
48
def __init__ (
@@ -45,6 +54,7 @@ def __init__(
45
54
root_tol = 1e-6 ,
46
55
extrap_tol = 0 ,
47
56
max_steps = "deprecated" ,
57
+ sensitivity = None
48
58
):
49
59
self ._method = method
50
60
self ._rtol = rtol
@@ -63,6 +73,7 @@ def __init__(
63
73
self .name = "Base solver"
64
74
self .ode_solver = False
65
75
self .algebraic_solver = False
76
+ self .sensitivity = sensitivity
66
77
67
78
@property
68
79
def method (self ):
@@ -203,6 +214,10 @@ def set_up(self, model, inputs=None, t_eval=None):
203
214
y = pybamm .StateVector (slice (0 , model .concatenated_initial_conditions .size ))
204
215
# set up Jacobian object, for re-use of dict
205
216
jacobian = pybamm .Jacobian ()
217
+ jacobian_parameters = {
218
+ p : pybamm .Jacobian () for p in inputs .keys ()
219
+ }
220
+
206
221
else :
207
222
# Convert model attributes to casadi
208
223
t_casadi = casadi .MX .sym ("t" )
@@ -225,32 +240,56 @@ def report(string):
225
240
226
241
if use_jacobian is None :
227
242
use_jacobian = model .use_jacobian
228
- if model .convert_to_format != "casadi" :
229
- # Process with pybamm functions
230
243
231
- if model .convert_to_format == "jax" :
232
- report (f"Converting { name } to jax" )
233
- jax_func = pybamm .EvaluatorJax (func )
244
+ if model .convert_to_format == "jax" :
245
+ report (f"Converting { name } to jax" )
246
+ func = pybamm .EvaluatorJax (func )
247
+ if self .sensitivity :
248
+ report (f"Calculating sensitivities for { name } using jax" )
249
+ jacp_dict = func .get_sensitivities ()
250
+ else :
251
+ jacp_dict = None
252
+ if use_jacobian :
253
+ report (f"Calculating jacobian for { name } using jax" )
254
+ jac = func .get_jacobian ()
255
+ jac = jac .evaluate
256
+ else :
257
+ jac = None
258
+
259
+ func = func .evaluate
260
+
261
+ elif model .convert_to_format != "casadi" :
262
+ # Process with pybamm functions, optionally converting
263
+ # to python evaluator
264
+ if self .sensitivity :
265
+ report (f"Calculating sensitivities for { name } " )
266
+ jacp_dict = {
267
+ p : jwrtp .jac (func , pybamm .InputParameter (p ))
268
+ for jwrtp , p in
269
+ zip (jacobian_parameters , inputs .keys ())
270
+ }
271
+ if model .convert_to_format == "python" :
272
+ report (f"Converting sensitivities for { name } to python" )
273
+ jacp_dict = {
274
+ p : pybamm .EvaluatorPython (jacp )
275
+ for p , jacp in jacp_dict .items ()
276
+ }
277
+ else :
278
+ jacp_dict = None
234
279
235
280
if use_jacobian :
236
281
report (f"Calculating jacobian for { name } " )
237
282
jac = jacobian .jac (func , y )
238
283
if model .convert_to_format == "python" :
239
284
report (f"Converting jacobian for { name } to python" )
240
285
jac = pybamm .EvaluatorPython (jac )
241
- elif model .convert_to_format == "jax" :
242
- report (f"Converting jacobian for { name } to jax" )
243
- jac = jax_func .get_jacobian ()
244
286
jac = jac .evaluate
245
287
else :
246
288
jac = None
247
289
248
290
if model .convert_to_format == "python" :
249
291
report (f"Converting { name } to python" )
250
292
func = pybamm .EvaluatorPython (func )
251
- if model .convert_to_format == "jax" :
252
- report (f"Converting { name } to jax" )
253
- func = jax_func
254
293
255
294
func = func .evaluate
256
295
@@ -266,6 +305,16 @@ def report(string):
266
305
)
267
306
else :
268
307
jac = None
308
+
309
+ if self .sensitivity :
310
+ report (f"Calculating sensitivities for { name } using CasADi" )
311
+ jacp_dict = {
312
+ name : casadi .jacobian (func , p )
313
+ for name , p in p_casadi .items ()
314
+ }
315
+ else :
316
+ jacp_dict = None
317
+
269
318
func = casadi .Function (
270
319
name , [t_casadi , y_casadi , p_casadi_stacked ], [func ]
271
320
)
0 commit comments