@@ -426,7 +426,8 @@ def solve(self, model, t_eval, external_variables=None, inputs=None):
426
426
427
427
# Calculate discontinuities
428
428
discontinuities = [
429
- event .expression .evaluate (u = inputs ) for event in model .discontinuity_events_eval
429
+ event .expression .evaluate (u = inputs )
430
+ for event in model .discontinuity_events_eval
430
431
]
431
432
432
433
# make sure they are increasing in time
@@ -436,31 +437,33 @@ def solve(self, model, t_eval, external_variables=None, inputs=None):
436
437
)
437
438
# remove any identical discontinuities
438
439
discontinuities = [
439
- v for i , v in enumerate (discontinuities )
440
- if i == len (discontinuities )- 1 or discontinuities [i ] < discontinuities [i + 1 ]
441
- ]
440
+ v for i , v in enumerate (discontinuities )
441
+ if i == len (discontinuities ) - 1
442
+ or discontinuities [i ] < discontinuities [i + 1 ]
443
+ ]
442
444
443
445
# insert time points around discontinuities in t_eval
444
446
# keep track of sub sections to integrate by storing start and end indices
445
447
start_indices = [0 ]
446
448
end_indices = []
447
449
for dtime in discontinuities :
448
450
dindex = np .searchsorted (t_eval , dtime , side = 'left' )
449
- end_indices .append (dindex + 1 )
450
- start_indices .append (dindex + 1 )
451
+ end_indices .append (dindex + 1 )
452
+ start_indices .append (dindex + 1 )
451
453
if t_eval [dindex ] == dtime :
452
454
t_eval [dindex ] += sys .float_info .epsilon
453
455
t_eval = np .insert (t_eval , dindex , dtime - sys .float_info .epsilon )
454
456
else :
455
457
t_eval = np .insert (t_eval , dindex ,
456
- [dtime - sys .float_info .epsilon , dtime + sys .float_info .epsilon ])
458
+ [dtime - sys .float_info .epsilon ,
459
+ dtime + sys .float_info .epsilon ])
457
460
end_indices .append (len (t_eval ))
458
461
459
462
old_y0 = model .y0
460
463
solution = None
461
464
for start_index , end_index in zip (start_indices , end_indices ):
462
465
pybamm .logger .info ("Calling solver for {} < t < {}"
463
- .format (t_eval [start_index ], t_eval [end_index - 1 ]))
466
+ .format (t_eval [start_index ], t_eval [end_index - 1 ]))
464
467
timer .reset ()
465
468
if solution is None :
466
469
solution = self ._integrate (
@@ -479,14 +482,15 @@ def solve(self, model, t_eval, external_variables=None, inputs=None):
479
482
# setup for next integration subsection
480
483
y0_guess = solution .y [:, - 1 ]
481
484
if model .algebraic :
482
- model .y0 = self .calculate_consistent_state (model , t_eval [end_index ], y0_guess )
485
+ model .y0 = self .calculate_consistent_state (
486
+ model , t_eval [end_index ], y0_guess )
483
487
else :
484
488
model .y0 = y0_guess
485
489
486
490
last_state = solution .y [:, - 1 ]
487
491
if len (model .algebraic ) > 0 :
488
492
model .y0 = self .calculate_consistent_state (
489
- model , t_eval [end_index ], last_state )
493
+ model , t_eval [end_index ], last_state )
490
494
else :
491
495
model .y0 = last_state
492
496
0 commit comments