@@ -55,6 +55,9 @@ class CasadiSolver(pybamm.BaseSolver):
55
55
Any options to pass to the CasADi integrator when calling the integrator.
56
56
Please consult `CasADi documentation <https://tinyurl.com/y5rk76os>`_ for
57
57
details.
58
+ solve_sensitivity_equations : bool, optional
59
+ Whether to explicitly formulate and solve the forward sensitivity equations.
60
+ See :class:`pybamm.BaseSolver`
58
61
59
62
"""
60
63
@@ -69,8 +72,16 @@ def __init__(
69
72
dt_max = None ,
70
73
extra_options_setup = None ,
71
74
extra_options_call = None ,
75
+ solve_sensitivity_equations = False ,
72
76
):
73
- super ().__init__ ("problem dependent" , rtol , atol , root_method , root_tol )
77
+ super ().__init__ (
78
+ "problem dependent" ,
79
+ rtol ,
80
+ atol ,
81
+ root_method ,
82
+ root_tol ,
83
+ solve_sensitivity_equations = solve_sensitivity_equations ,
84
+ )
74
85
if mode in ["safe" , "fast" ]:
75
86
self .mode = mode
76
87
else :
@@ -106,24 +117,26 @@ def _integrate(self, model, t_eval, inputs=None):
106
117
Any external variables or input parameters to pass to the model when solving
107
118
"""
108
119
# Record whether there are any symbolic inputs
109
- inputs = inputs or {}
110
- has_symbolic_inputs = any (isinstance (v , casadi .MX ) for v in inputs .values ())
120
+ inputs_dict = inputs or {}
121
+ has_symbolic_inputs = any (
122
+ isinstance (v , casadi .MX ) for v in inputs_dict .values ()
123
+ )
111
124
112
125
# convert inputs to casadi format
113
- inputs = casadi .vertcat (* [x for x in inputs .values ()])
126
+ inputs = casadi .vertcat (* [x for x in inputs_dict .values ()])
114
127
115
128
if has_symbolic_inputs :
116
129
# Create integrax`tor without grid to avoid having to create several times
117
130
self .get_integrator (model , inputs )
118
- solution = self ._run_integrator (model , model .y0 , inputs , t_eval )
131
+ solution = self ._run_integrator (model , model .y0 , inputs_dict , t_eval )
119
132
solution .termination = "final time"
120
133
return solution
121
134
elif self .mode == "fast" or not model .events :
122
135
if not model .events :
123
136
pybamm .logger .info ("No events found, running fast mode" )
124
137
# Create an integrator with the grid (we just need to do this once)
125
138
self .get_integrator (model , inputs , t_eval )
126
- solution = self ._run_integrator (model , model .y0 , inputs , t_eval )
139
+ solution = self ._run_integrator (model , model .y0 , inputs_dict , t_eval )
127
140
solution .termination = "final time"
128
141
return solution
129
142
elif self .mode == "safe" :
@@ -143,7 +156,9 @@ def _integrate(self, model, t_eval, inputs=None):
143
156
pybamm .logger .info ("Start solving {} with {}" .format (model .name , self .name ))
144
157
145
158
# Initialize solution
146
- solution = pybamm .Solution (np .array ([t ]), y0 [:, np .newaxis ])
159
+ solution = pybamm .Solution (
160
+ np .array ([t ]), y0 [:, np .newaxis ], model = model , inputs = inputs_dict
161
+ )
147
162
solution .solve_time = 0
148
163
149
164
# Try to integrate in global steps of size dt_max. Note: dt_max must
@@ -178,7 +193,7 @@ def _integrate(self, model, t_eval, inputs=None):
178
193
# halve the step size and try again.
179
194
try :
180
195
current_step_sol = self ._run_integrator (
181
- model , y0 , inputs , t_window
196
+ model , y0 , inputs_dict , t_window
182
197
)
183
198
solved = True
184
199
except pybamm .SolverError :
@@ -257,7 +272,9 @@ def event_fun(t):
257
272
t_window = np .array ([t , t_event ])
258
273
259
274
# integrator = self.get_integrator(model, t_window, inputs)
260
- current_step_sol = self ._run_integrator (model , y0 , inputs , t_window )
275
+ current_step_sol = self ._run_integrator (
276
+ model , y0 , inputs_dict , t_window
277
+ )
261
278
262
279
# assign temporary solve time
263
280
current_step_sol .solve_time = np .nan
@@ -361,10 +378,18 @@ def get_integrator(self, model, inputs, t_eval=None):
361
378
self .integrators [model ] = (integrator , use_grid )
362
379
return integrator
363
380
364
- def _run_integrator (self , model , y0 , inputs , t_eval ):
381
+ def _run_integrator (self , model , y0 , inputs_dict , t_eval ):
382
+ inputs = casadi .vertcat (* [x for x in inputs_dict .values ()])
365
383
integrator , use_grid = self .integrators [model ]
366
- y0_diff = y0 [: model .len_rhs ]
367
- y0_alg = y0 [model .len_rhs :]
384
+ # Split up initial conditions into differential and algebraic
385
+ # Check y0 to see if it includes sensitivities
386
+ if model .len_rhs_and_alg == y0 .shape [0 ]:
387
+ len_rhs = model .len_rhs
388
+ else :
389
+ len_rhs = model .len_rhs * (inputs .shape [0 ] + 1 )
390
+ y0_diff = y0 [:len_rhs ]
391
+ y0_alg = y0 [len_rhs :]
392
+ # Solve
368
393
try :
369
394
# Try solving
370
395
if use_grid is True :
@@ -379,7 +404,7 @@ def _run_integrator(self, model, y0, inputs, t_eval):
379
404
** self .extra_options_call
380
405
)
381
406
y_sol = np .concatenate ([sol ["xf" ].full (), sol ["zf" ].full ()])
382
- return pybamm .Solution (t_eval , y_sol )
407
+ return pybamm .Solution (t_eval , y_sol , model = model , inputs = inputs_dict )
383
408
else :
384
409
# Repeated calls to the integrator
385
410
x = y0_diff
@@ -399,10 +424,14 @@ def _run_integrator(self, model, y0, inputs, t_eval):
399
424
if not z .is_empty ():
400
425
y_alg = casadi .horzcat (y_alg , z )
401
426
if z .is_empty ():
402
- return pybamm .Solution (t_eval , y_diff )
427
+ return pybamm .Solution (
428
+ t_eval , y_diff , model = model , inputs = inputs_dict
429
+ )
403
430
else :
404
431
y_sol = casadi .vertcat (y_diff , y_alg )
405
- return pybamm .Solution (t_eval , y_sol )
432
+ return pybamm .Solution (
433
+ t_eval , y_sol , model = model , inputs = inputs_dict
434
+ )
406
435
except RuntimeError as e :
407
436
# If it doesn't work raise error
408
437
raise pybamm .SolverError (e .args [0 ])
0 commit comments