@@ -59,9 +59,12 @@ class CasadiSolver(pybamm.BaseSolver):
59
59
Please consult `CasADi documentation <https://tinyurl.com/y5rk76os>`_ for
60
60
details.
61
61
sensitivity : bool, optional
62
- Whether to explicitly formulate and solve the forward sensitivity equations.
63
- See :class:`pybamm.BaseSolver`
62
+ Whether (and how) to calculate sensitivities when solving. Options are:
64
63
64
+ - None: no sensitivities
65
+ - "explicit forward": explicitly formulate the sensitivity equations.
66
+ See :class:`pybamm.BaseSolver`
67
+ - "casadi": use casadi to differentiate through the integrator
65
68
"""
66
69
67
70
def __init__ (
@@ -104,6 +107,7 @@ def __init__(
104
107
# Initialize
105
108
self .integrators = {}
106
109
self .integrator_specs = {}
110
+ self .y_sols = {}
107
111
108
112
pybamm .citations .register ("Andersson2019" )
109
113
@@ -122,24 +126,29 @@ def _integrate(self, model, t_eval, inputs=None):
122
126
"""
123
127
# Record whether there are any symbolic inputs
124
128
inputs_dict = inputs or {}
125
- has_symbolic_inputs = any (
126
- isinstance (v , casadi .MX ) for v in inputs_dict .values ()
127
- )
128
129
129
130
# convert inputs to casadi format
130
131
inputs = casadi .vertcat (* [x for x in inputs_dict .values ()])
131
132
132
- if has_symbolic_inputs :
133
- # Create integrator without grid to avoid having to create several times
134
- self .create_integrator (model , inputs )
135
- solution = self ._run_integrator (model , model .y0 , inputs_dict , t_eval )
133
+ if self .sensitivity == "casadi" and inputs_dict != {}:
134
+ # If the solution has already been created, we can reuse it
135
+ if model in self .y_sols :
136
+ y_sol = self .y_sols [model ]
137
+ solution = pybamm .Solution (
138
+ t_eval , y_sol , model = model , inputs = inputs_dict
139
+ )
140
+ else :
141
+ # Create integrator without grid, which will be called repeatedly
142
+ # This is necessary for casadi to compute sensitivities
143
+ self .create_integrator (model , inputs_dict )
144
+ solution = self ._run_integrator (model , model .y0 , inputs_dict , t_eval )
136
145
solution .termination = "final time"
137
146
return solution
138
147
elif self .mode == "fast" or not model .events :
139
148
if not model .events :
140
149
pybamm .logger .info ("No events found, running fast mode" )
141
150
# Create an integrator with the grid (we just need to do this once)
142
- self .create_integrator (model , inputs , t_eval )
151
+ self .create_integrator (model , inputs_dict , t_eval )
143
152
solution = self ._run_integrator (model , model .y0 , inputs_dict , t_eval )
144
153
solution .termination = "final time"
145
154
return solution
@@ -161,7 +170,7 @@ def _integrate(self, model, t_eval, inputs=None):
161
170
# in "safe without grid" mode,
162
171
# create integrator once, without grid,
163
172
# to avoid having to create several times
164
- self .create_integrator (model , inputs )
173
+ self .create_integrator (model , inputs_dict )
165
174
# Initialize solution
166
175
solution = pybamm .Solution (
167
176
np .array ([t ]), y0 [:, np .newaxis ], model = model , inputs = inputs_dict
@@ -314,12 +323,15 @@ def event_fun(t):
314
323
y0 = solution .y [:, - 1 ]
315
324
return solution
316
325
317
- def create_integrator (self , model , inputs , t_eval = None ):
326
+ def create_integrator (self , model , inputs_dict , t_eval = None ):
318
327
"""
319
328
Method to create a casadi integrator object.
320
329
If t_eval is provided, the integrator uses t_eval to make the grid.
321
330
Otherwise, the integrator has grid [0,1].
322
331
"""
332
+ # convert inputs to casadi format
333
+ inputs = casadi .vertcat (* [x for x in inputs_dict .values ()])
334
+
323
335
# Use grid if t_eval is given
324
336
use_grid = not (t_eval is None )
325
337
# Only set up problem once
@@ -400,6 +412,13 @@ def create_integrator(self, model, inputs, t_eval=None):
400
412
401
413
def _run_integrator (self , model , y0 , inputs_dict , t_eval ):
402
414
inputs = casadi .vertcat (* [x for x in inputs_dict .values ()])
415
+ symbolic_inputs = casadi .MX .sym ("inputs" , inputs .shape [0 ])
416
+ # If doing sensitivity with casadi, evaluate with symbolic inputs
417
+ # Otherwise, evaluate with actual inputs
418
+ if self .sensitivity == "casadi" :
419
+ inputs_eval = symbolic_inputs
420
+ else :
421
+ inputs_eval = inputs
403
422
integrator , use_grid = self .integrators [model ]
404
423
# Split up initial conditions into differential and algebraic
405
424
# Check y0 to see if it includes sensitivities
@@ -415,10 +434,9 @@ def _run_integrator(self, model, y0, inputs_dict, t_eval):
415
434
if use_grid is True :
416
435
# Call the integrator once, with the grid
417
436
sol = integrator (
418
- x0 = y0_diff , z0 = y0_alg , p = inputs , ** self .extra_options_call
437
+ x0 = y0_diff , z0 = y0_alg , p = inputs_eval , ** self .extra_options_call
419
438
)
420
439
y_sol = np .concatenate ([sol ["xf" ].full (), sol ["zf" ].full ()])
421
- return pybamm .Solution (t_eval , y_sol , model = model , inputs = inputs_dict )
422
440
else :
423
441
# Repeated calls to the integrator
424
442
x = y0_diff
@@ -428,7 +446,7 @@ def _run_integrator(self, model, y0, inputs_dict, t_eval):
428
446
for i in range (len (t_eval ) - 1 ):
429
447
t_min = t_eval [i ]
430
448
t_max = t_eval [i + 1 ]
431
- inputs_with_tlims = casadi .vertcat (inputs , t_min , t_max )
449
+ inputs_with_tlims = casadi .vertcat (inputs_eval , t_min , t_max )
432
450
sol = integrator (
433
451
x0 = x , z0 = z , p = inputs_with_tlims , ** self .extra_options_call
434
452
)
@@ -438,14 +456,15 @@ def _run_integrator(self, model, y0, inputs_dict, t_eval):
438
456
if not z .is_empty ():
439
457
y_alg = casadi .horzcat (y_alg , z )
440
458
if z .is_empty ():
441
- return pybamm .Solution (
442
- t_eval , y_diff , model = model , inputs = inputs_dict
443
- )
459
+ y_sol = y_diff
444
460
else :
445
461
y_sol = casadi .vertcat (y_diff , y_alg )
446
- return pybamm .Solution (
447
- t_eval , y_sol , model = model , inputs = inputs_dict
448
- )
462
+ # If doing sensitivity, return the solution as a function of the inputs
463
+ if self .sensitivity == "casadi" :
464
+ y_sol = casadi .Function ("y_sol" , [symbolic_inputs ], [y_sol ])
465
+ # Save the solution, can just reuse and change the inputs
466
+ self .y_sols [model ] = y_sol
467
+ return pybamm .Solution (t_eval , y_sol , model = model , inputs = inputs_dict )
449
468
except RuntimeError as e :
450
469
# If it doesn't work raise error
451
470
raise pybamm .SolverError (e .args [0 ])
0 commit comments