@@ -40,24 +40,25 @@ class ProcessedVariable(object):
40
40
variable. Note that this can be any kind of node in the expression tree, not
41
41
just a :class:`pybamm.Variable`.
42
42
When evaluated, returns an array of size (m,n)
43
+ base_variable_casadi : :class:`casadi.Function`
44
+ A casadi function. When evaluated, returns the same thing as
45
+ `base_Variable.evaluate` (but more efficiently).
43
46
solution : :class:`pybamm.Solution`
44
47
The solution object to be used to create the processed variables
45
- known_evals : dict
46
- Dictionary of known evaluations, to be used to speed up finding the solution
47
48
warn : bool, optional
48
49
Whether to raise warnings when trying to evaluate time and length scales.
49
50
Default is True.
50
51
"""
51
52
52
- def __init__ (self , base_variable , solution , known_evals = None , warn = True ):
53
+ def __init__ (self , base_variable , base_variable_casadi , solution , warn = True ):
53
54
self .base_variable = base_variable
55
+ self .base_variable_casadi = base_variable_casadi
54
56
self .t_sol = solution .t
55
57
self .u_sol = solution .y
56
58
self .mesh = base_variable .mesh
57
59
self .inputs = solution .inputs
58
60
self .domain = base_variable .domain
59
61
self .auxiliary_domains = base_variable .auxiliary_domains
60
- self .known_evals = known_evals
61
62
self .warn = warn
62
63
63
64
# Sensitivity starts off uninitialized, only set when called
@@ -104,19 +105,10 @@ def __init__(self, base_variable, solution, known_evals=None, warn=True):
104
105
self .length_scales = solution .length_scales_eval
105
106
106
107
# Evaluate base variable at initial time
107
- if self .known_evals :
108
- self .base_eval , self .known_evals [solution .t [0 ]] = base_variable .evaluate (
109
- self .t_sol [0 ],
110
- self .u_sol [:, 0 ],
111
- inputs = {name : inp [:, 0 ] for name , inp in solution .inputs .items ()},
112
- known_evals = self .known_evals [solution .t [0 ]],
113
- )
114
- else :
115
- self .base_eval = base_variable .evaluate (
116
- solution .t [0 ],
117
- solution .y [:, 0 ],
118
- inputs = {name : inp [:, 0 ] for name , inp in solution .inputs .items ()},
119
- )
108
+ inputs = casadi .vertcat (* [inp [:, 0 ] for inp in self .inputs .values ()])
109
+ self .base_eval = self .base_variable_casadi (
110
+ solution .t [0 ], solution .y [:, 0 ], inputs
111
+ ).full ()
120
112
121
113
# handle 2D (in space) finite element variables differently
122
114
if (
@@ -164,13 +156,8 @@ def initialise_0D(self):
164
156
for idx in range (len (self .t_sol )):
165
157
t = self .t_sol [idx ]
166
158
u = self .u_sol [:, idx ]
167
- inputs = {name : inp [:, idx ] for name , inp in self .inputs .items ()}
168
- if self .known_evals :
169
- entries [idx ], self .known_evals [t ] = self .base_variable .evaluate (
170
- t , u , inputs = inputs , known_evals = self .known_evals [t ]
171
- )
172
- else :
173
- entries [idx ] = self .base_variable .evaluate (t , u , inputs = inputs )
159
+ inputs = casadi .vertcat (* [inp [:, idx ] for inp in self .inputs .values ()])
160
+ entries [idx ] = self .base_variable_casadi (t , u , inputs ).full ()[0 , 0 ]
174
161
175
162
# set up interpolation
176
163
if len (self .t_sol ) == 1 :
@@ -200,15 +187,8 @@ def initialise_1D(self, fixed_t=False):
200
187
for idx in range (len (self .t_sol )):
201
188
t = self .t_sol [idx ]
202
189
u = self .u_sol [:, idx ]
203
- inputs = {name : inp [:, idx ] for name , inp in self .inputs .items ()}
204
- if self .known_evals :
205
- eval_and_known_evals = self .base_variable .evaluate (
206
- t , u , inputs = inputs , known_evals = self .known_evals [t ]
207
- )
208
- entries [:, idx ] = eval_and_known_evals [0 ][:, 0 ]
209
- self .known_evals [t ] = eval_and_known_evals [1 ]
210
- else :
211
- entries [:, idx ] = self .base_variable .evaluate (t , u , inputs = inputs )[:, 0 ]
190
+ inputs = casadi .vertcat (* [inp [:, idx ] for inp in self .inputs .values ()])
191
+ entries [:, idx ] = self .base_variable_casadi (t , u , inputs ).full ()[:, 0 ]
212
192
213
193
# Get node and edge values
214
194
nodes = self .mesh .nodes
@@ -310,23 +290,12 @@ def initialise_2D(self):
310
290
for idx in range (len (self .t_sol )):
311
291
t = self .t_sol [idx ]
312
292
u = self .u_sol [:, idx ]
313
- inputs = {name : inp [:, idx ] for name , inp in self .inputs .items ()}
314
- if self .known_evals :
315
- eval_and_known_evals = self .base_variable .evaluate (
316
- t , u , inputs = inputs , known_evals = self .known_evals [t ]
317
- )
318
- entries [:, :, idx ] = np .reshape (
319
- eval_and_known_evals [0 ],
320
- [first_dim_size , second_dim_size ],
321
- order = "F" ,
322
- )
323
- self .known_evals [t ] = eval_and_known_evals [1 ]
324
- else :
325
- entries [:, :, idx ] = np .reshape (
326
- self .base_variable .evaluate (t , u , inputs = inputs ),
327
- [first_dim_size , second_dim_size ],
328
- order = "F" ,
329
- )
293
+ inputs = casadi .vertcat (* [inp [:, idx ] for inp in self .inputs .values ()])
294
+ entries [:, :, idx ] = np .reshape (
295
+ self .base_variable_casadi (t , u , inputs ).full (),
296
+ [first_dim_size , second_dim_size ],
297
+ order = "F" ,
298
+ )
330
299
331
300
# add points outside first dimension domain for extrapolation to
332
301
# boundaries
@@ -463,22 +432,13 @@ def initialise_2D_scikit_fem(self):
463
432
for idx in range (len (self .t_sol )):
464
433
t = self .t_sol [idx ]
465
434
u = self .u_sol [:, idx ]
466
- inputs = { name : inp [:, idx ] for name , inp in self .inputs .items ()}
435
+ inputs = casadi . vertcat ( * [ inp [:, idx ] for inp in self .inputs .values ()])
467
436
468
- if self .known_evals :
469
- eval_and_known_evals = self .base_variable .evaluate (
470
- t , u , inputs = inputs , known_evals = self .known_evals [t ]
471
- )
472
- entries [:, :, idx ] = np .reshape (
473
- eval_and_known_evals [0 ], [len_y , len_z ], order = "F"
474
- )
475
- self .known_evals [t ] = eval_and_known_evals [1 ]
476
- else :
477
- entries [:, :, idx ] = np .reshape (
478
- self .base_variable .evaluate (t , u , inputs = inputs ),
479
- [len_y , len_z ],
480
- order = "F" ,
481
- )
437
+ entries [:, :, idx ] = np .reshape (
438
+ self .base_variable_casadi (t , u , inputs ).full (),
439
+ [len_y , len_z ],
440
+ order = "F" ,
441
+ )
482
442
483
443
# assign attributes for reference
484
444
self .entries = entries
0 commit comments