7
7
import numpy as np
8
8
from scipy import optimize
9
9
from scipy .sparse import issparse
10
+ import sys
10
11
11
12
12
13
class BaseSolver (object ):
@@ -218,13 +219,15 @@ def report(string):
218
219
)
219
220
terminate_events_eval = [
220
221
process (event .expression , "event" , use_jacobian = False )[1 ]
221
- for event in model .events
222
- if events . type == pybamm .EventType .TERMINATION
222
+ for event in model .events
223
+ if event . event_type == pybamm .EventType .TERMINATION
223
224
]
225
+
226
+ # discontinuity events are evaluated before the solver is called, so don't need
227
+ # to process them
224
228
discontinuity_events_eval = [
225
- process (event .expression , "event" , use_jacobian = False )[1 ]
226
- for event in model .events
227
- if events .type == pybamm .EventType .DISCONTINUITY
229
+ event for event in model .events
230
+ if event .event_type == pybamm .EventType .DISCONTINUITY
228
231
]
229
232
230
233
# Add the solver attributes
@@ -243,7 +246,8 @@ def report(string):
243
246
residuals , residuals_eval , jacobian_eval = process (all_states , "residuals" )
244
247
model .residuals_eval = residuals_eval
245
248
model .jacobian_eval = jacobian_eval
246
- model .y0 = self .calculate_consistent_initial_conditions (model )
249
+ y0_guess = model .concatenated_initial_conditions .flatten ()
250
+ model .y0 = self .calculate_consistent_state (model , 0 , y0_guess )
247
251
else :
248
252
# can use DAE solver to solve ODE model
249
253
model .residuals_eval = Residuals (rhs , "residuals" , model )
@@ -281,14 +285,12 @@ def set_inputs(self, model, ext_and_inputs):
281
285
model .residuals_eval .set_inputs (ext_and_inputs )
282
286
for evnt in model .terminate_events_eval :
283
287
evnt .set_inputs (ext_and_inputs )
284
- for evnt in model .discontinuity_events_eval :
285
- evnt .set_inputs (ext_and_inputs )
286
288
if model .jacobian_eval :
287
289
model .jacobian_eval .set_inputs (ext_and_inputs )
288
290
289
- def calculate_consistent_initial_conditions (self , model ):
291
+ def calculate_consistent_state (self , model , time = 0 , y0_guess = None ):
290
292
"""
291
- Calculate consistent initial conditions for the algebraic equations through
293
+ Calculate consistent state for the algebraic equations through
292
294
root-finding
293
295
294
296
Parameters
@@ -305,8 +307,9 @@ def calculate_consistent_initial_conditions(self, model):
305
307
pybamm .logger .info ("Start calculating consistent initial conditions" )
306
308
rhs = model .rhs_eval
307
309
algebraic = model .algebraic_eval
308
- y0_guess = model .concatenated_initial_conditions .flatten ()
309
310
jac = model .jac_algebraic_eval
311
+ if y0_guess is None :
312
+ y0_guess = model .concatenated_initial_conditions .flatten ()
310
313
311
314
# Split y0_guess into differential and algebraic
312
315
len_rhs = rhs (0 , y0_guess ).shape [0 ]
@@ -315,7 +318,7 @@ def calculate_consistent_initial_conditions(self, model):
315
318
def root_fun (y0_alg ):
316
319
"Evaluates algebraic using y0_diff (fixed) and y0_alg (changed by algo)"
317
320
y0 = np .concatenate ([y0_diff , y0_alg ])
318
- out = algebraic (0 , y0 )
321
+ out = algebraic (time , y0 )
319
322
pybamm .logger .debug (
320
323
"Evaluating algebraic equations at t=0, L2-norm is {}" .format (
321
324
np .linalg .norm (out )
@@ -421,13 +424,77 @@ def solve(self, model, t_eval, external_variables=None, inputs=None):
421
424
# Set inputs and external
422
425
self .set_inputs (model , ext_and_inputs )
423
426
424
- timer .reset ()
425
- pybamm .logger .info ("Calling solver" )
426
- solution = self ._integrate (model , t_eval , ext_and_inputs )
427
+ # Calculate discontinuities
428
+ discontinuities = [
429
+ event .expression .evaluate (u = inputs ) for event in model .discontinuity_events_eval
430
+ ]
431
+
432
+ # make sure they are increasing in time
433
+ discontinuities = sorted (discontinuities )
434
+ pybamm .logger .info (
435
+ 'Discontinuity events found at t = {}' .format (discontinuities )
436
+ )
437
+ # remove any identical discontinuities
438
+ discontinuities = [
439
+ v for i , v in enumerate (discontinuities )
440
+ if i == len (discontinuities )- 1 or discontinuities [i ] < discontinuities [i + 1 ]
441
+ ]
442
+
443
+ # insert time points around discontinuities in t_eval
444
+ # keep track of sub sections to integrate by storing start and end indices
445
+ start_indices = [0 ]
446
+ end_indices = []
447
+ for dtime in discontinuities :
448
+ dindex = np .searchsorted (t_eval , dtime , side = 'left' )
449
+ end_indices .append (dindex + 1 )
450
+ start_indices .append (dindex + 1 )
451
+ if t_eval [dindex ] == dtime :
452
+ t_eval [dindex ] += sys .float_info .epsilon
453
+ t_eval = np .insert (t_eval , dindex , dtime - sys .float_info .epsilon )
454
+ else :
455
+ t_eval = np .insert (t_eval , dindex ,
456
+ [dtime - sys .float_info .epsilon , dtime + sys .float_info .epsilon ])
457
+ end_indices .append (len (t_eval ))
458
+
459
+ old_y0 = model .y0
460
+ solution = None
461
+ for start_index , end_index in zip (start_indices , end_indices ):
462
+ pybamm .logger .info ("Calling solver for {} < t < {}"
463
+ .format (t_eval [start_index ], t_eval [end_index - 1 ]))
464
+ timer .reset ()
465
+ if solution is None :
466
+ solution = self ._integrate (
467
+ model , t_eval [start_index :end_index ], ext_and_inputs )
468
+ solution .solve_time = timer .time ()
469
+ else :
470
+ new_solution = self ._integrate (
471
+ model , t_eval [start_index :end_index ], ext_and_inputs )
472
+ new_solution .solve_time = timer .time ()
473
+ solution .append (new_solution , start_index = 0 )
474
+
475
+ if solution .termination != "final time" :
476
+ break
477
+
478
+ if end_index != len (t_eval ):
479
+ # setup for next integration subsection
480
+ y0_guess = solution .y [:, - 1 ]
481
+ if model .algebraic :
482
+ model .y0 = self .calculate_consistent_state (model , t_eval [end_index ], y0_guess )
483
+ else :
484
+ model .y0 = y0_guess
485
+
486
+ last_state = solution .y [:, - 1 ]
487
+ if len (model .algebraic ) > 0 :
488
+ model .y0 = self .calculate_consistent_state (
489
+ model , t_eval [end_index ], last_state )
490
+ else :
491
+ model .y0 = last_state
492
+
493
+ # restore old y0
494
+ model .y0 = old_y0
427
495
428
496
# Assign times
429
497
solution .set_up_time = set_up_time
430
- solution .solve_time = timer .time ()
431
498
432
499
# Add model and inputs to solution
433
500
solution .model = model
@@ -571,7 +638,7 @@ def get_termination_reason(self, solution, events):
571
638
final_event_values = {}
572
639
573
640
for event in events :
574
- if event .type == pybamm .EventType .TERMINATION :
641
+ if event .event_type == pybamm .EventType .TERMINATION :
575
642
final_event_values [event .name ] = abs (
576
643
event .expression .evaluate (
577
644
solution .t_event ,
0 commit comments