@@ -68,7 +68,7 @@ def __init__(
68
68
self .extra_options = extra_options
69
69
self .name = "CasADi solver ({}) with '{}' mode" .format (method , mode )
70
70
71
- def solve (self , model , t_eval , inputs = None ):
71
+ def solve (self , model , t_eval , external_variables = None , inputs = None ):
72
72
"""
73
73
Execute the solver setup and calculate the solution of the model at
74
74
specified times.
@@ -80,6 +80,9 @@ def solve(self, model, t_eval, inputs=None):
80
80
initial_conditions
81
81
t_eval : numeric type
82
82
The times at which to compute the solution
83
+ external_variables : dict
84
+ A dictionary of external variables and their corresponding
85
+ values at the current time
83
86
inputs : dict, optional
84
87
Any input parameters to pass to the model when solving
85
88
@@ -93,11 +96,15 @@ def solve(self, model, t_eval, inputs=None):
93
96
"""
94
97
if self .mode == "fast" :
95
98
# Solve model normally by calling the solve method from parent class
96
- return super ().solve (model , t_eval , inputs = inputs )
99
+ return super ().solve (
100
+ model , t_eval , external_variables = external_variables , inputs = inputs
101
+ )
97
102
elif model .events == {}:
98
103
pybamm .logger .info ("No events found, running fast mode" )
99
104
# Solve model normally by calling the solve method from parent class
100
- return super ().solve (model , t_eval , inputs = inputs )
105
+ return super ().solve (
106
+ model , t_eval , external_variables = external_variables , inputs = inputs
107
+ )
101
108
elif self .mode == "safe" :
102
109
# Step-and-check
103
110
timer = pybamm .Timer ()
@@ -122,7 +129,12 @@ def solve(self, model, t_eval, inputs=None):
122
129
# different to t_eval, but shouldn't matter too much as it should
123
130
# only happen near events.
124
131
try :
125
- current_step_sol = self .step (model , dt , inputs = inputs )
132
+ current_step_sol = self .step (
133
+ model ,
134
+ dt ,
135
+ external_variables = external_variables ,
136
+ inputs = inputs ,
137
+ )
126
138
solved = True
127
139
except pybamm .SolverError :
128
140
dt /= 2
@@ -229,6 +241,11 @@ def integrate_casadi(self, rhs, algebraic, y0, t_eval, inputs=None):
229
241
Any input parameters to pass to the model when solving
230
242
"""
231
243
inputs = inputs or {}
244
+ if self .y_ext is None :
245
+ y_ext = np .array ([])
246
+ else :
247
+ y_ext = self .y_ext
248
+
232
249
options = {
233
250
"grid" : t_eval ,
234
251
"reltol" : self .rtol ,
@@ -242,16 +259,15 @@ def integrate_casadi(self, rhs, algebraic, y0, t_eval, inputs=None):
242
259
# set up and solve
243
260
t = casadi .MX .sym ("t" )
244
261
u = casadi .vertcat (* [x for x in inputs .values ()])
245
- y_diff = self .y_diff
262
+ y0_w_ext = casadi .vertcat (y0 , y_ext [len (y0 ) :])
263
+ y_diff = casadi .MX .sym ("y_diff" , rhs (0 , y0_w_ext , u ).shape [0 ])
246
264
problem = {"t" : t , "x" : y_diff }
247
265
if algebraic is None :
248
- y_casadi_w_ext = casadi .vertcat (y_diff , self . y_ext [y_diff . shape [ 0 ] :])
266
+ y_casadi_w_ext = casadi .vertcat (y_diff , y_ext [len ( y0 ) :])
249
267
problem .update ({"ode" : rhs (t , y_casadi_w_ext , u )})
250
268
else :
251
269
y_alg = self .y_alg
252
- y_casadi_w_ext = casadi .vertcat (
253
- y_diff , y_alg , self .y_ext [y_diff .shape [0 ] + y_alg .shape [0 ] :]
254
- )
270
+ y_casadi_w_ext = casadi .vertcat (y_diff , y_alg , y_ext [len (y0 ) :])
255
271
problem .update (
256
272
{
257
273
"z" : y_alg ,
0 commit comments