@@ -142,7 +142,9 @@ def set_up(self, model, inputs=None):
142
142
)
143
143
144
144
inputs = inputs or {}
145
- y0 = model .concatenated_initial_conditions .evaluate (0 , None , inputs = inputs )
145
+ model .y0 = model .concatenated_initial_conditions .evaluate (
146
+ 0 , None , inputs = inputs
147
+ ).flatten ()
146
148
147
149
# Set model timescale
148
150
model .timescale_eval = model .timescale .evaluate (inputs = inputs )
@@ -169,18 +171,19 @@ def set_up(self, model, inputs=None):
169
171
if model .convert_to_format != "casadi" :
170
172
simp = pybamm .Simplification ()
171
173
# Create Jacobian from concatenated rhs and algebraic
172
- y = pybamm .StateVector (slice (0 , np .size (y0 )))
174
+ y = pybamm .StateVector (slice (0 , np .size (model . y0 )))
173
175
# set up Jacobian object, for re-use of dict
174
176
jacobian = pybamm .Jacobian ()
175
177
else :
176
178
# Convert model attributes to casadi
177
179
t_casadi = casadi .MX .sym ("t" )
178
180
y_diff = casadi .MX .sym (
179
- "y_diff" , len (model .concatenated_rhs .evaluate (0 , y0 , inputs = inputs ))
181
+ "y_diff" ,
182
+ len (model .concatenated_rhs .evaluate (0 , model .y0 , inputs = inputs )),
180
183
)
181
184
y_alg = casadi .MX .sym (
182
185
"y_alg" ,
183
- len (model .concatenated_algebraic .evaluate (0 , y0 , inputs = inputs )),
186
+ len (model .concatenated_algebraic .evaluate (0 , model . y0 , inputs = inputs )),
184
187
)
185
188
y_casadi = casadi .vertcat (y_diff , y_alg )
186
189
p_casadi = {}
@@ -322,36 +325,69 @@ def report(string):
322
325
"rhs" , [t_casadi , y_casadi , p_casadi_stacked ], [explicit_rhs ]
323
326
)
324
327
model .casadi_algebraic = algebraic
325
- if self .algebraic_solver is True :
326
- # we don't calculate consistent initial conditions
327
- # for an algebraic solver as this will be the job of the algebraic solver
328
+ if len (model .rhs ) == 0 :
329
+ # No rhs equations: residuals is algebraic only
328
330
model .residuals_eval = Residuals (algebraic , "residuals" , model )
329
331
model .jacobian_eval = jac_algebraic
330
- model .y0 = y0 .flatten ()
331
332
elif len (model .algebraic ) == 0 :
332
- # can use DAE solver to solve ODE model
333
- # - no initial condition initialization needed
333
+ # No algebraic equations: residuals is rhs only
334
334
model .residuals_eval = Residuals (rhs , "residuals" , model )
335
335
model .jacobian_eval = jac_rhs
336
- model .y0 = y0 .flatten ()
337
336
# Calculate consistent initial conditions for the algebraic equations
338
337
else :
339
- if len (model .rhs ) > 0 :
340
- all_states = pybamm .NumpyConcatenation (
341
- model .concatenated_rhs , model .concatenated_algebraic
338
+ all_states = pybamm .NumpyConcatenation (
339
+ model .concatenated_rhs , model .concatenated_algebraic
340
+ )
341
+ # Process again, uses caching so should be quick
342
+ residuals_eval , jacobian_eval = process (all_states , "residuals" )[1 :]
343
+ model .residuals_eval = residuals_eval
344
+ model .jacobian_eval = jacobian_eval
345
+
346
+ pybamm .logger .info ("Finish solver set-up" )
347
+
348
+ def _set_initial_conditions (self , model , inputs , update_rhs ):
349
+ """
350
+ Set initial conditions for the model. This is skipped if the solver is an
351
+ algebraic solver (since this would make the algebraic solver redundant), and if
352
+ the model doesn't have any algebraic equations (since there are no initial
353
+ conditions to be calculated in this case).
354
+
355
+ Parameters
356
+ ----------
357
+ model : :class:`pybamm.BaseModel`
358
+ The model for which to calculate initial conditions.
359
+ inputs : dict
360
+ Any input parameters to pass to the model when solving
361
+ update_rhs : bool
362
+ Whether to update the rhs. True for 'solve', False for 'step'.
363
+
364
+ """
365
+ if self .algebraic_solver is True :
366
+ return None
367
+ elif len (model .algebraic ) == 0 :
368
+ if update_rhs is True :
369
+ # Recalculate initial conditions for the rhs equations
370
+ model .y0 = model .concatenated_initial_conditions .evaluate (
371
+ 0 , None , inputs = inputs
372
+ ).flatten ()
373
+ else :
374
+ return None
375
+ else :
376
+ if update_rhs is True :
377
+ # Recalculate initial conditions for the rhs equations
378
+ y0_from_inputs = model .concatenated_initial_conditions .evaluate (
379
+ 0 , None , inputs = inputs
380
+ ).flatten ()
381
+ # Reuse old solution for algebraic equations
382
+ y0_from_model = model .y0
383
+ len_rhs = len (
384
+ model .concatenated_rhs .evaluate (0 , model .y0 , inputs = inputs )
342
385
)
343
- # Process again, uses caching so should be quick
344
- residuals_eval , jacobian_eval = process (all_states , "residuals" )[1 :]
345
- model .residuals_eval = residuals_eval
346
- model .jacobian_eval = jacobian_eval
386
+ y0_guess = np .r_ [y0_from_inputs [:len_rhs ], y0_from_model [len_rhs :]]
347
387
else :
348
- model .residuals_eval = Residuals (algebraic , "residuals" , model )
349
- model .jacobian_eval = jac_algebraic
350
- y0_guess = y0 .flatten ()
388
+ y0_guess = model .y0
351
389
model .y0 = self .calculate_consistent_state (model , 0 , y0_guess , inputs )
352
390
353
- pybamm .logger .info ("Finish solver set-up" )
354
-
355
391
def calculate_consistent_state (self , model , time = 0 , y0_guess = None , inputs = None ):
356
392
"""
357
393
Calculate consistent state for the algebraic equations through
@@ -480,12 +516,9 @@ def jac_fn(y0_alg):
480
516
)
481
517
else :
482
518
raise pybamm .SolverError (
483
- """
484
- Could not find consistent initial conditions: solver terminated
485
- successfully, but maximum solution error ({}) above tolerance ({})
486
- """ .format (
487
- max_fun , self .root_tol
488
- )
519
+ "Could not find consistent initial conditions: solver terminated "
520
+ "successfully, but maximum solution error "
521
+ "({}) above tolerance ({})" .format (max_fun , self .root_tol )
489
522
)
490
523
491
524
def solve (self , model , t_eval = None , external_variables = None , inputs = None ):
@@ -555,6 +588,10 @@ def solve(self, model, t_eval=None, external_variables=None, inputs=None):
555
588
self .models_set_up .add (model )
556
589
else :
557
590
set_up_time = 0
591
+
592
+ # (Re-)calculate consistent initial conditions
593
+ self ._set_initial_conditions (model , ext_and_inputs , update_rhs = True )
594
+
558
595
# Non-dimensionalise time
559
596
t_eval_dimensionless = t_eval / model .timescale_eval
560
597
# Solve
@@ -758,6 +795,9 @@ def step(
758
795
model .y0 = old_solution .y [:, - 1 ]
759
796
set_up_time = 0
760
797
798
+ # (Re-)calculate consistent initial conditions
799
+ self ._set_initial_conditions (model , ext_and_inputs , update_rhs = False )
800
+
761
801
# Non-dimensionalise dt
762
802
dt_dimensionless = dt / model .timescale_eval
763
803
# Step
0 commit comments