@@ -216,16 +216,20 @@ def report(string):
216
216
# Check for heaviside functions in rhs and algebraic and add discontinuity
217
217
# events if these exist.
218
218
# Note: only checks for the case of t < X, t <= X, X < t, or X <= t
219
- for symbol in itertools .chain (model .concatenated_rhs .pre_order (),
220
- model .concatenated_algebraic .pre_order ()):
219
+ for symbol in itertools .chain (
220
+ model .concatenated_rhs .pre_order (), model .concatenated_algebraic .pre_order ()
221
+ ):
221
222
if isinstance (symbol , pybamm .Heaviside ):
222
223
if symbol .right .id == pybamm .t .id :
223
224
expr = symbol .left
224
225
elif symbol .left .id == pybamm .t .id :
225
226
expr = symbol .right
226
227
227
- model .events .append (pybamm .Event (str (symbol ), expr .new_copy (),
228
- pybamm .EventType .DISCONTINUITY ))
228
+ model .events .append (
229
+ pybamm .Event (
230
+ str (symbol ), expr .new_copy (), pybamm .EventType .DISCONTINUITY
231
+ )
232
+ )
229
233
230
234
# Process rhs, algebraic and event expressions
231
235
rhs , rhs_eval , jac_rhs = process (model .concatenated_rhs , "RHS" )
@@ -241,7 +245,8 @@ def report(string):
241
245
# discontinuity events are evaluated before the solver is called, so don't need
242
246
# to process them
243
247
discontinuity_events_eval = [
244
- event for event in model .events
248
+ event
249
+ for event in model .events
245
250
if event .event_type == pybamm .EventType .DISCONTINUITY
246
251
]
247
252
@@ -448,11 +453,12 @@ def solve(self, model, t_eval, external_variables=None, inputs=None):
448
453
# make sure they are increasing in time
449
454
discontinuities = sorted (discontinuities )
450
455
pybamm .logger .info (
451
- ' Discontinuity events found at t = {}' .format (discontinuities )
456
+ " Discontinuity events found at t = {}" .format (discontinuities )
452
457
)
453
458
# remove any identical discontinuities
454
459
discontinuities = [
455
- v for i , v in enumerate (discontinuities )
460
+ v
461
+ for i , v in enumerate (discontinuities )
456
462
if i == len (discontinuities ) - 1
457
463
or discontinuities [i ] < discontinuities [i + 1 ]
458
464
]
@@ -462,16 +468,18 @@ def solve(self, model, t_eval, external_variables=None, inputs=None):
462
468
start_indices = [0 ]
463
469
end_indices = []
464
470
for dtime in discontinuities :
465
- dindex = np .searchsorted (t_eval , dtime , side = ' left' )
471
+ dindex = np .searchsorted (t_eval , dtime , side = " left" )
466
472
end_indices .append (dindex + 1 )
467
473
start_indices .append (dindex + 1 )
468
474
if t_eval [dindex ] == dtime :
469
475
t_eval [dindex ] += sys .float_info .epsilon
470
476
t_eval = np .insert (t_eval , dindex , dtime - sys .float_info .epsilon )
471
477
else :
472
- t_eval = np .insert (t_eval , dindex ,
473
- [dtime - sys .float_info .epsilon ,
474
- dtime + sys .float_info .epsilon ])
478
+ t_eval = np .insert (
479
+ t_eval ,
480
+ dindex ,
481
+ [dtime - sys .float_info .epsilon , dtime + sys .float_info .epsilon ],
482
+ )
475
483
end_indices .append (len (t_eval ))
476
484
477
485
# integrate separatly over each time segment and accumulate into the solution
@@ -480,16 +488,21 @@ def solve(self, model, t_eval, external_variables=None, inputs=None):
480
488
old_y0 = model .y0
481
489
solution = None
482
490
for start_index , end_index in zip (start_indices , end_indices ):
483
- pybamm .logger .info ("Calling solver for {} < t < {}"
484
- .format (t_eval [start_index ], t_eval [end_index - 1 ]))
491
+ pybamm .logger .info (
492
+ "Calling solver for {} < t < {}" .format (
493
+ t_eval [start_index ], t_eval [end_index - 1 ]
494
+ )
495
+ )
485
496
timer .reset ()
486
497
if solution is None :
487
498
solution = self ._integrate (
488
- model , t_eval [start_index :end_index ], ext_and_inputs )
499
+ model , t_eval [start_index :end_index ], ext_and_inputs
500
+ )
489
501
solution .solve_time = timer .time ()
490
502
else :
491
503
new_solution = self ._integrate (
492
- model , t_eval [start_index :end_index ], ext_and_inputs )
504
+ model , t_eval [start_index :end_index ], ext_and_inputs
505
+ )
493
506
new_solution .solve_time = timer .time ()
494
507
solution .append (new_solution , start_index = 0 )
495
508
@@ -501,14 +514,16 @@ def solve(self, model, t_eval, external_variables=None, inputs=None):
501
514
y0_guess = solution .y [:, - 1 ]
502
515
if model .algebraic :
503
516
model .y0 = self .calculate_consistent_state (
504
- model , t_eval [end_index ], y0_guess )
517
+ model , t_eval [end_index ], y0_guess
518
+ )
505
519
else :
506
520
model .y0 = y0_guess
507
521
508
522
last_state = solution .y [:, - 1 ]
509
523
if len (model .algebraic ) > 0 :
510
524
model .y0 = self .calculate_consistent_state (
511
- model , t_eval [end_index ], last_state )
525
+ model , t_eval [end_index ], last_state
526
+ )
512
527
else :
513
528
model .y0 = last_state
514
529
0 commit comments