@@ -134,7 +134,8 @@ def copy(self):
134
134
new_solver .models_set_up = {}
135
135
return new_solver
136
136
137
- def set_up (self , model , inputs = None , t_eval = None ):
137
+ def set_up (self , model , inputs = None , t_eval = None ,
138
+ calculate_sensitivites = False ):
138
139
"""Unpack model, perform checks, and calculate jacobian.
139
140
140
141
Parameters
@@ -146,6 +147,10 @@ def set_up(self, model, inputs=None, t_eval=None):
146
147
Any input parameters to pass to the model when solving
147
148
t_eval : numeric type, optional
148
149
The times (in seconds) at which to compute the solution
150
+ calculate_sensitivites : list of str or bool
151
+ If true, solver calculates sensitivities of all input parameters.
152
+ If only a subset of sensitivities are required, can also pass a
153
+ list of input parameter names
149
154
150
155
"""
151
156
pybamm .logger .info ("Start solver set-up" )
@@ -209,14 +214,28 @@ def set_up(self, model, inputs=None, t_eval=None):
209
214
)
210
215
model .convert_to_format = "casadi"
211
216
217
+ # find all the input parameters in the model
218
+ input_parameters = {}
219
+ for equation in [model .concatenated_rhs ,
220
+ model .concatenated_algebraic ,
221
+ model .concatenated_initial_conditions ]:
222
+ input_parameters .update ({
223
+ symbol ._id : symbol for symbol in equation .pre_order ()
224
+ if isinstance (symbol , pybamm .InputParameter )
225
+ })
226
+
227
+ # from here on, calculate_sensitivites is now only a list
228
+ if isinstance (calculate_sensitivites , bool ):
229
+ if calculate_sensitivites :
230
+ calculate_sensitivites = [p for p in inputs .keys ()]
231
+ else :
232
+ calculate_sensitivites = []
233
+
212
234
if model .convert_to_format != "casadi" :
213
235
# Create Jacobian from concatenated rhs and algebraic
214
236
y = pybamm .StateVector (slice (0 , model .concatenated_initial_conditions .size ))
215
237
# set up Jacobian object, for re-use of dict
216
238
jacobian = pybamm .Jacobian ()
217
- jacobian_parameters = {
218
- p : pybamm .Jacobian () for p in inputs .keys ()
219
- }
220
239
221
240
else :
222
241
# Convert model attributes to casadi
@@ -244,8 +263,11 @@ def report(string):
244
263
if model .convert_to_format == "jax" :
245
264
report (f"Converting { name } to jax" )
246
265
func = pybamm .EvaluatorJax (func )
247
- if self .sensitivity :
248
- report (f"Calculating sensitivities for { name } using jax" )
266
+ if calculate_sensitivites :
267
+ report ((
268
+ f"Calculating sensitivities for { name } with respect "
269
+ f"to parameters { calculate_sensitivites } using jax"
270
+ ))
249
271
jacp_dict = func .get_sensitivities ()
250
272
else :
251
273
jacp_dict = None
@@ -261,19 +283,24 @@ def report(string):
261
283
elif model .convert_to_format != "casadi" :
262
284
# Process with pybamm functions, optionally converting
263
285
# to python evaluator
264
- if self .sensitivity :
265
- report (f"Calculating sensitivities for { name } " )
286
+ print ('calculate_sensitivites = ' , calculate_sensitivites )
287
+ if calculate_sensitivites :
288
+ report ((
289
+ f"Calculating sensitivities for { name } with respect "
290
+ f"to parameters { calculate_sensitivites } "
291
+ ))
292
+ print (type (func ))
266
293
jacp_dict = {
267
- p : jwrtp .jac (func , pybamm .InputParameter (p ))
268
- for jwrtp , p in
269
- zip (jacobian_parameters , inputs .keys ())
294
+ p : func .diff (pybamm .InputParameter (p ))
295
+ for p in calculate_sensitivites
270
296
}
271
297
if model .convert_to_format == "python" :
272
298
report (f"Converting sensitivities for { name } to python" )
273
299
jacp_dict = {
274
300
p : pybamm .EvaluatorPython (jacp )
275
301
for p , jacp in jacp_dict .items ()
276
302
}
303
+ jacp_dict = {k : v .evaluate for k , v in jacp_dict .items ()}
277
304
else :
278
305
jacp_dict = None
279
306
@@ -306,12 +333,18 @@ def report(string):
306
333
else :
307
334
jac = None
308
335
309
- if self .sensitivity :
310
- report (f"Calculating sensitivities for { name } using CasADi" )
311
- jacp_dict = {
312
- name : casadi .jacobian (func , p )
313
- for name , p in p_casadi .items ()
314
- }
336
+ if calculate_sensitivites :
337
+ report ((
338
+ f"Calculating sensitivities for { name } with respect "
339
+ f"to parameters { calculate_sensitivites } using CasADi"
340
+ ))
341
+ jacp_dict = {}
342
+ for pname in calculate_sensitivites :
343
+ p_diff = casadi .jacobian (func , p_casadi [pname ])
344
+ jacp_dict [pname ] = casadi .Function (
345
+ name , [t_casadi , y_casadi , p_casadi_stacked ],
346
+ [p_diff ]
347
+ )
315
348
else :
316
349
jacp_dict = None
317
350
@@ -326,7 +359,12 @@ def report(string):
326
359
jac_call = SolverCallable (jac , name + "_jac" , model )
327
360
else :
328
361
jac_call = None
329
- return func , func_call , jac_call
362
+ if jacp_dict is not None :
363
+ jacp_call = {
364
+ k : SolverCallable (v , name + "_sensitivity_wrt_" + k , model )
365
+ for k , v in jacp_dict .items ()
366
+ }
367
+ return func , func_call , jac_call , jacp_call
330
368
331
369
# Check for heaviside and modulo functions in rhs and algebraic and add
332
370
# discontinuity events if these exist.
@@ -400,8 +438,8 @@ def report(string):
400
438
init_eval = InitialConditions (initial_conditions , model )
401
439
402
440
# Process rhs, algebraic and event expressions
403
- rhs , rhs_eval , jac_rhs = process (model .concatenated_rhs , "RHS" )
404
- algebraic , algebraic_eval , jac_algebraic = process (
441
+ rhs , rhs_eval , jac_rhs , jacp_rhs = process (model .concatenated_rhs , "RHS" )
442
+ algebraic , algebraic_eval , jac_algebraic , jacp_algebraic = process (
405
443
model .concatenated_algebraic , "algebraic"
406
444
)
407
445
@@ -486,19 +524,23 @@ def report(string):
486
524
# No rhs equations: residuals is algebraic only
487
525
model .residuals_eval = Residuals (algebraic , "residuals" , model )
488
526
model .jacobian_eval = jac_algebraic
527
+ model .sensitivities_eval = jacp_algebraic
489
528
elif len (model .algebraic ) == 0 :
490
529
# No algebraic equations: residuals is rhs only
491
530
model .residuals_eval = Residuals (rhs , "residuals" , model )
492
531
model .jacobian_eval = jac_rhs
532
+ model .sensitivities_eval = jacp_rhs
493
533
# Calculate consistent initial conditions for the algebraic equations
494
534
else :
495
535
all_states = pybamm .NumpyConcatenation (
496
536
model .concatenated_rhs , model .concatenated_algebraic
497
537
)
498
538
# Process again, uses caching so should be quick
499
- residuals_eval , jacobian_eval = process (all_states , "residuals" )[1 :]
539
+ residuals_eval , jacobian_eval , jacobian_wrtp_eval = \
540
+ process (all_states , "residuals" )[1 :]
500
541
model .residuals_eval = residuals_eval
501
542
model .jacobian_eval = jacobian_eval
543
+ model .sensitivities_eval = jacobian_wrtp_eval
502
544
503
545
pybamm .logger .info ("Finish solver set-up" )
504
546
@@ -589,6 +631,7 @@ def solve(
589
631
inputs = None ,
590
632
initial_conditions = None ,
591
633
nproc = None ,
634
+ calculate_sensitivities = False
592
635
):
593
636
"""
594
637
Execute the solver setup and calculate the solution of the model at
@@ -614,6 +657,10 @@ def solve(
614
657
nproc : int, optional
615
658
Number of processes to use when solving for more than one set of input
616
659
parameters. Defaults to value returned by "os.cpu_count()".
660
+ calculate_sensitivites : list of str or bool
661
+ If true, solver calculates sensitivities of all input parameters.
662
+ If only a subset of sensitivities are required, can also pass a
663
+ list of input parameter names
617
664
618
665
Returns
619
666
-------
@@ -690,7 +737,8 @@ def solve(
690
737
# not depend on input parameters. Thefore only `ext_and_inputs[0]`
691
738
# is passed to `set_up`.
692
739
# See https://github.com/pybamm-team/PyBaMM/pull/1261
693
- self .set_up (model , ext_and_inputs_list [0 ], t_eval )
740
+ self .set_up (model , ext_and_inputs_list [0 ], t_eval ,
741
+ calculate_sensitivities )
694
742
self .models_set_up .update (
695
743
{model : {"initial conditions" : model .concatenated_initial_conditions }}
696
744
)
@@ -701,7 +749,8 @@ def solve(
701
749
# If the new initial conditions are different, set up again
702
750
# Doing the whole setup again might be slow, but no need to prematurely
703
751
# optimize this
704
- self .set_up (model , ext_and_inputs_list [0 ], t_eval )
752
+ self .set_up (model , ext_and_inputs_list [0 ], t_eval ,
753
+ calculate_sensitivities )
705
754
self .models_set_up [model ][
706
755
"initial conditions"
707
756
] = model .concatenated_initial_conditions
@@ -951,6 +1000,9 @@ def step(
951
1000
save : bool
952
1001
Turn on to store the solution of all previous timesteps
953
1002
1003
+
1004
+
1005
+
954
1006
Raises
955
1007
------
956
1008
:class:`pybamm.ModelError`
@@ -1241,12 +1293,13 @@ def __init__(self, function, name, model):
1241
1293
self .timescale = self .model .timescale_eval
1242
1294
1243
1295
def __call__ (self , t , y , inputs ):
1244
- if self .name in ["RHS" , "algebraic" , "residuals" ]:
1245
- pybamm .logger .debug (
1246
- "Evaluating {} for {} at t={}" .format (
1247
- self .name , self .model .name , t * self .timescale
1248
- )
1296
+ pybamm .logger .debug (
1297
+ "Evaluating {} for {} at t={}" .format (
1298
+ self .name , self .model .name , t * self .timescale
1249
1299
)
1300
+ )
1301
+ if self .name in ["RHS" , "algebraic" , "residuals" ]:
1302
+
1250
1303
return self .function (t , y , inputs ).flatten ()
1251
1304
else :
1252
1305
return self .function (t , y , inputs )
0 commit comments