@@ -53,12 +53,7 @@ def __init__(
53
53
inputs = None ,
54
54
):
55
55
self .t = t
56
- if isinstance (y , casadi .DM ):
57
- y = y .full ()
58
-
59
- # if inputs are None, initialize empty, to be populated later
60
- inputs = inputs or pybamm .FuzzyDict ()
61
- self .set_inputs (inputs )
56
+ self .inputs = inputs
62
57
63
58
# If the model has been provided, split up y into solution and sensitivity
64
59
# Don't do this if the sensitivity equations have not been computed (i.e. if
@@ -107,8 +102,12 @@ def __init__(
107
102
# tn_xn_p0, tn_xn_p1, ..., tn_xn_pn
108
103
# 1, Extract rhs and alg sensitivities and reshape into 3D matrices
109
104
# with shape (n_p, n_states, n_t)
110
- ode_sens = y [n_rhs :len_rhs_and_sens , :].reshape (n_p , n_rhs , n_t )
111
- alg_sens = y [len_rhs_and_sens + n_alg :, :].reshape (n_p , n_alg , n_t )
105
+ if isinstance (y , casadi .DM ):
106
+ y_full = y .full ()
107
+ else :
108
+ y_full = y
109
+ ode_sens = y_full [n_rhs :len_rhs_and_sens , :].reshape (n_p , n_rhs , n_t )
110
+ alg_sens = y_full [len_rhs_and_sens + n_alg :, :].reshape (n_p , n_alg , n_t )
112
111
# 2. Concatenate into a single 3D matrix with shape (n_p, n_states, n_t)
113
112
# i.e. along first axis
114
113
full_sens_matrix = np .concatenate ([ode_sens , alg_sens ], axis = 1 )
@@ -163,8 +162,16 @@ def y(self):
163
162
164
163
@y .setter
165
164
def y (self , y ):
166
- self ._y = y
167
- self ._y_MX = casadi .MX .sym ("y" , y .shape [0 ])
165
+ if isinstance (y , casadi .Function ):
166
+ self ._y_fn = None
167
+ inputs_stacked = casadi .vertcat (* self .inputs .values ())
168
+ self ._y = y (inputs_stacked )
169
+ self ._y_sym = y (self ._symbolic_inputs )
170
+ else :
171
+ self ._y = y
172
+ self ._y_fn = None
173
+ self ._y_sym = None
174
+ self ._y_MX = casadi .MX .sym ("y" , self ._y .shape [0 ])
168
175
169
176
@property
170
177
def model (self ):
@@ -196,8 +203,12 @@ def inputs(self):
196
203
"Values of the inputs"
197
204
return self ._inputs
198
205
199
- def set_inputs (self , inputs ):
206
+ @inputs .setter
207
+ def inputs (self , inputs ):
200
208
"Updates the input values"
209
+ # if inputs are None, initialize empty, to be populated later
210
+ inputs = inputs or pybamm .FuzzyDict ()
211
+
201
212
# self._inputs = {}
202
213
# for name, inp in inputs.items():
203
214
# # Convert number to vector of the right shape
@@ -233,13 +244,13 @@ def set_inputs(self, inputs):
233
244
inp = inp [:, np .newaxis ]
234
245
inp = np .tile (inp , len (self .t ))
235
246
self ._inputs [name ] = inp
236
- self ._all_inputs_as_MX_dict = {}
237
- for key , value in self ._inputs .items ():
238
- self ._all_inputs_as_MX_dict [key ] = casadi .MX .sym ("input" , value .shape [0 ])
247
+ self ._symbolic_inputs_dict = {
248
+ name : casadi .MX .sym (name , value .shape [0 ])
249
+ for name , value in self .inputs .items ()
250
+ }
239
251
240
- self ._all_inputs_as_MX = casadi .vertcat (
241
- * [p for p in self ._all_inputs_as_MX_dict .values ()]
242
- )
252
+ # The symbolic_inputs will be used for sensitivity
253
+ self ._symbolic_inputs = casadi .vertcat (* self ._symbolic_inputs_dict .values ())
243
254
244
255
@property
245
256
def t_event (self ):
@@ -298,12 +309,12 @@ def update(self, variables):
298
309
# Convert variable to casadi
299
310
# Make all inputs symbolic first for converting to casadi
300
311
var_sym = var_pybamm .to_casadi (
301
- self ._t_MX , self ._y_MX , inputs = self ._all_inputs_as_MX_dict
312
+ self ._t_MX , self ._y_MX , inputs = self ._symbolic_inputs_dict
302
313
)
303
314
304
315
var_casadi = casadi .Function (
305
316
"variable" ,
306
- [self ._t_MX , self ._y_MX , self ._all_inputs_as_MX ],
317
+ [self ._t_MX , self ._y_MX , self ._symbolic_inputs ],
307
318
[var_sym ],
308
319
)
309
320
self .model ._variables_casadi [key ] = var_casadi
@@ -359,8 +370,8 @@ def clear_casadi_attributes(self):
359
370
"Remove casadi objects for pickling, will be computed again automatically"
360
371
self ._t_MX = None
361
372
self ._y_MX = None
362
- self ._all_inputs_as_MX = None
363
- self ._all_inputs_as_MX_dict = None
373
+ self ._symbolic_inputs = None
374
+ self ._symbolic_inputs_dict = None
364
375
365
376
def save (self , filename ):
366
377
"""Save the whole solution using pickle"""
0 commit comments