@@ -42,11 +42,7 @@ class BaseSolver(object):
42
42
the solution instance returned. At the moment this is only implemented for the
43
43
IDAKLU solver.\
44
44
- "explicit forward": explicitly formulate the sensitivity equations for
45
- the chosen input parameters. The formulation is as per
46
- "Park, S., Kato, D., Gima, Z., Klein, R., & Moura, S. (2018).\
47
- Optimal experimental design for parameterization of an electrochemical
48
- lithium-ion battery model. Journal of The Electrochemical\
49
- Society, 165(7), A1309.". See #1100 for details. At the moment this is only
45
+ the chosen input parameters. . At the moment this is only
50
46
implemented using convert_to_format = 'casadi'. \
51
47
- see individual solvers for other options
52
48
"""
@@ -60,7 +56,6 @@ def __init__(
60
56
root_tol = 1e-6 ,
61
57
extrap_tol = 0 ,
62
58
max_steps = "deprecated" ,
63
- sensitivity = None ,
64
59
):
65
60
self ._method = method
66
61
self ._rtol = rtol
@@ -79,7 +74,6 @@ def __init__(
79
74
self .name = "Base solver"
80
75
self .ode_solver = False
81
76
self .algebraic_solver = False
82
- self .sensitivity = sensitivity
83
77
84
78
@property
85
79
def method (self ):
@@ -140,8 +134,6 @@ def copy(self):
140
134
new_solver .models_set_up = {}
141
135
return new_solver
142
136
143
-
144
-
145
137
def set_up (self , model , inputs = None , t_eval = None ,
146
138
calculate_sensitivites = False ):
147
139
"""Unpack model, perform checks, and calculate jacobian.
@@ -238,21 +230,17 @@ def set_up(self, model, inputs=None, t_eval=None,
238
230
calculate_sensitivites = [p for p in inputs .keys ()]
239
231
else :
240
232
calculate_sensitivites = []
233
+
234
+ calculate_sensitivites_explicit = False
235
+ if calculate_sensitivites and not isinstance (self , pybamm .IDAKLUSolver ):
236
+ calculate_sensitivites_explicit = True
237
+
241
238
# save sensitivity parameters so we can identify them later on
242
239
# (FYI: this is used in the Solution class)
243
240
model .calculate_sensitivities = calculate_sensitivites
244
- model .len_rhs_sens = model .len_rhs * len (calculate_sensitivites )
245
- model .len_alg_sens = model .len_alg * len (calculate_sensitivites )
246
-
247
- # Only allow solving explicit sensitivity equations with the casadi format for now
248
- if (
249
- self .sensitivity == "explicit forward"
250
- and model .convert_to_format != "casadi"
251
- ):
252
- raise NotImplementedError (
253
- "model should be converted to casadi format in order to solve "
254
- "explicit sensitivity equations"
255
- )
241
+ if calculate_sensitivites_explicit :
242
+ model .len_rhs_sens = model .len_rhs * len (calculate_sensitivites )
243
+ model .len_alg_sens = model .len_alg * len (calculate_sensitivites )
256
244
257
245
if model .convert_to_format != "casadi" :
258
246
# Create Jacobian from concatenated rhs and algebraic
@@ -275,7 +263,7 @@ def set_up(self, model, inputs=None, t_eval=None,
275
263
p_casadi [name ] = casadi .MX .sym (name , value .shape [0 ])
276
264
p_casadi_stacked = casadi .vertcat (* [p for p in p_casadi .values ()])
277
265
# sensitivity vectors
278
- if self . sensitivity == "explicit forward" :
266
+ if calculate_sensitivites_explicit :
279
267
pS_casadi_stacked = casadi .vertcat (
280
268
* [p_casadi [name ] for name in calculate_sensitivites ]
281
269
)
@@ -297,15 +285,19 @@ def report(string):
297
285
if model .convert_to_format == "jax" :
298
286
report (f"Converting { name } to jax" )
299
287
func = pybamm .EvaluatorJax (func )
300
- if calculate_sensitivites :
288
+ jacp = None
289
+ if calculate_sensitivites_explicit :
290
+ raise NotImplementedError (
291
+ "sensitivities using convert_to_format = 'jax' "
292
+ "only implemented for IDAKLUSolver"
293
+ )
294
+ elif calculate_sensitivites :
301
295
report ((
302
296
f"Calculating sensitivities for { name } with respect "
303
297
f"to parameters { calculate_sensitivites } using jax"
304
298
))
305
299
jacp = func .get_sensitivities ()
306
300
jacp = jacp .evaluate
307
- else :
308
- jacp = None
309
301
if use_jacobian :
310
302
report (f"Calculating jacobian for { name } using jax" )
311
303
jac = func .get_jacobian ()
@@ -319,6 +311,10 @@ def report(string):
319
311
# Process with pybamm functions, optionally converting
320
312
# to python evaluator
321
313
if calculate_sensitivites :
314
+ raise NotImplementedError (
315
+ "sensitivities only implemented with "
316
+ "convert_to_format = 'casadi' or convert_to_format = 'jax'"
317
+ )
322
318
report ((
323
319
f"Calculating sensitivities for { name } with respect "
324
320
f"to parameters { calculate_sensitivites } "
@@ -362,9 +358,16 @@ def jacp(*args, **kwargs):
362
358
report (f"Converting { name } to CasADi" )
363
359
func = func .to_casadi (t_casadi , y_casadi , inputs = p_casadi )
364
360
# Add sensitivity vectors to the rhs and algebraic equations
365
- if self .sensitivity == "explicit forward" :
361
+ jacp = None
362
+ if calculate_sensitivites_explicit :
363
+ # The formulation is as per Park, S., Kato, D., Gima, Z., Klein, R.,
364
+ # & Moura, S. (2018). Optimal experimental design for
365
+ # parameterization of an electrochemical lithium-ion battery model.
366
+ # Journal of The Electrochemical Society, 165(7), A1309.". See #1100
367
+ # for details
366
368
if name == "rhs" and model .len_rhs > 0 :
367
- report ("Creating sensitivity equations for rhs using CasADi" )
369
+ report (
370
+ "Creating explicit forward sensitivity equations for rhs using CasADi" )
368
371
df_dx = casadi .jacobian (func , y_diff )
369
372
df_dp = casadi .jacobian (func , pS_casadi_stacked )
370
373
S_x_mat = S_x .reshape (
@@ -383,7 +386,7 @@ def jacp(*args, **kwargs):
383
386
func = casadi .vertcat (func , S_rhs )
384
387
if name == "algebraic" and model .len_alg > 0 :
385
388
report (
386
- "Creating sensitivity equations for algebraic using CasADi"
389
+ "Creating explicit forward sensitivity equations for algebraic using CasADi"
387
390
)
388
391
dg_dz = casadi .jacobian (func , y_alg )
389
392
dg_dp = casadi .jacobian (func , pS_casadi_stacked )
@@ -401,7 +404,12 @@ def jacp(*args, **kwargs):
401
404
(- 1 , 1 )
402
405
)
403
406
func = casadi .vertcat (func , S_alg )
404
- elif name == "initial_conditions" :
407
+ if name == "residuals" :
408
+ raise NotImplementedError (
409
+ "explicit forward equations not implimented for residuals"
410
+ )
411
+
412
+ if name == "initial_conditions" :
405
413
if model .len_rhs == 0 or model .len_alg == 0 :
406
414
S_0 = casadi .jacobian (func , pS_casadi_stacked ).reshape (
407
415
(- 1 , 1 )
@@ -417,16 +425,7 @@ def jacp(*args, **kwargs):
417
425
(- 1 , 1 )
418
426
)
419
427
func = casadi .vertcat (x0 , Sx_0 , z0 , Sz_0 )
420
- if use_jacobian :
421
- report (f"Calculating jacobian for { name } using CasADi" )
422
- jac_casadi = casadi .jacobian (func , y_and_S )
423
- jac = casadi .Function (
424
- name , [t_casadi , y_and_S , p_casadi_stacked ], [jac_casadi ]
425
- )
426
- else :
427
- jac = None
428
-
429
- if calculate_sensitivites and self .sensitivity != "explicit forward" :
428
+ elif calculate_sensitivites :
430
429
report ((
431
430
f"Calculating sensitivities for { name } with respect "
432
431
f"to parameters { calculate_sensitivites } using CasADi"
@@ -444,8 +443,14 @@ def jacp(*args, **kwargs):
444
443
return {k : v (* args , ** kwargs )
445
444
for k , v in jacp_dict .items ()}
446
445
446
+ if use_jacobian :
447
+ report (f"Calculating jacobian for { name } using CasADi" )
448
+ jac_casadi = casadi .jacobian (func , y_and_S )
449
+ jac = casadi .Function (
450
+ name , [t_casadi , y_and_S , p_casadi_stacked ], [jac_casadi ]
451
+ )
447
452
else :
448
- jacp = None
453
+ jac = None
449
454
450
455
func = casadi .Function (
451
456
name , [t_casadi , y_and_S , p_casadi_stacked ], [func ]
@@ -538,7 +543,7 @@ def jacp(*args, **kwargs):
538
543
)[0 ]
539
544
init_eval = InitialConditions (initial_conditions , model )
540
545
541
- if self . sensitivity == "explicit forward" :
546
+ if calculate_sensitivites_explicit :
542
547
y0_total_size = (
543
548
model .len_rhs + model .len_rhs_sens
544
549
+ model .len_alg + model .len_alg_sens
@@ -555,7 +560,6 @@ def jacp(*args, **kwargs):
555
560
556
561
# Calculate initial conditions
557
562
model .y0 = init_eval (inputs )
558
- print ('YYYYY' , model .y0 )
559
563
560
564
casadi_terminate_events = []
561
565
terminate_events_eval = []
@@ -726,7 +730,6 @@ def _set_initial_conditions(self, model, inputs, update_rhs):
726
730
model .y0 = casadi .Function ("y0" , [symbolic_inputs ], [y0 ])
727
731
else :
728
732
model .y0 = y0
729
- print ('ASDF' , model .y0 )
730
733
731
734
def calculate_consistent_state (self , model , time = 0 , inputs = None ):
732
735
"""
0 commit comments