4
4
import casadi
5
5
import pybamm
6
6
import numpy as np
7
+ from scipy .interpolate import interp1d
8
+ from scipy .optimize import brentq
7
9
8
10
9
11
class CasadiSolver (pybamm .BaseSolver ):
@@ -22,8 +24,11 @@ class CasadiSolver(pybamm.BaseSolver):
22
24
- "fast": perform direct integration, without accounting for events. \
23
25
Recommended when simulating a drive cycle or other simulation where \
24
26
no events should be triggered.
25
- - "safe": perform step-and-check integration, checking whether events have \
26
- been triggered. Recommended for simulations of a full charge or discharge.
27
+ - "safe": perform step-and-check integration in global steps of size \
28
+ dt_max, checking whether events have been triggered. Recommended for \
29
+ simulations of a full charge or discharge.
30
+ - "old safe": perform step-and-check integration in steps of size dt \
31
+ for each dt in t_eval, checking whether events have been triggered.
27
32
rtol : float, optional
28
33
The relative tolerance for the solver (default is 1e-6).
29
34
atol : float, optional
@@ -40,6 +45,10 @@ class CasadiSolver(pybamm.BaseSolver):
40
45
max_step_decrease_counts : float, optional
41
46
The maximum number of times step size can be decreased before an error is
42
47
raised. Default is 5.
48
+ dt_max : float, optional
49
+ The maximum global step size (in seconds) used in "safe" mode. If None
50
+ the default value corresponds to a non-dimensional time of 0.01
51
+ (i.e. ``0.01 * model.timescale_eval``).
43
52
extra_options_setup : dict, optional
44
53
Any options to pass to the CasADi integrator when creating the integrator.
45
54
Please consult `CasADi documentation <https://tinyurl.com/y5rk76os>`_ for
@@ -59,23 +68,27 @@ def __init__(
59
68
root_method = "casadi" ,
60
69
root_tol = 1e-6 ,
61
70
max_step_decrease_count = 5 ,
71
+ dt_max = None ,
62
72
extra_options_setup = None ,
63
73
extra_options_call = None ,
64
74
):
65
75
super ().__init__ ("problem dependent" , rtol , atol , root_method , root_tol )
66
- if mode in ["safe" , "fast" ]:
76
+ if mode in ["safe" , "fast" , "old safe" ]:
67
77
self .mode = mode
68
78
else :
69
79
raise ValueError (
70
80
"""
71
- invalid mode '{}'. Must be either 'safe', for solving with events,
72
- or 'fast', for solving quickly without events""" .format (
81
+ invalid mode '{}'. Must be either 'safe' or 'old safe' , for solving
82
+ with events, or 'fast', for solving quickly without events""" .format (
73
83
mode
74
84
)
75
85
)
76
86
self .max_step_decrease_count = max_step_decrease_count
87
+ self .dt_max = dt_max
88
+
77
89
self .extra_options_setup = extra_options_setup or {}
78
90
self .extra_options_call = extra_options_call or {}
91
+
79
92
self .name = "CasADi solver with '{}' mode" .format (mode )
80
93
81
94
# Initialize
@@ -114,6 +127,149 @@ def _integrate(self, model, t_eval, inputs=None):
114
127
solution .termination = "final time"
115
128
return solution
116
129
elif self .mode == "safe" :
130
+ y0 = model .y0
131
+ if isinstance (y0 , casadi .DM ):
132
+ y0 = y0 .full ().flatten ()
133
+ # Step-and-check
134
+ t = t_eval [0 ]
135
+ t_f = t_eval [- 1 ]
136
+ init_event_signs = np .sign (
137
+ np .concatenate (
138
+ [event (t , y0 , inputs ) for event in model .terminate_events_eval ]
139
+ )
140
+ )
141
+ pybamm .logger .info ("Start solving {} with {}" .format (model .name , self .name ))
142
+
143
+ # Initialize solution
144
+ solution = pybamm .Solution (np .array ([t ]), y0 [:, np .newaxis ])
145
+ solution .solve_time = 0
146
+
147
+ # Try to integrate in global steps of size dt_max. Note: dt_max must
148
+ # be at least as big as the the biggest step in t_eval (multiplied
149
+ # by some tolerance, here 1.01) to avoid an empty integration window below
150
+ if self .dt_max :
151
+ # Non-dimensionalise provided dt_max
152
+ dt_max = self .dt_max / model .timescale_eval
153
+ else :
154
+ dt_max = 0.01
155
+ dt_eval_max = np .max (np .diff (t_eval )) * 1.01
156
+ dt_max = np .max ([dt_max , dt_eval_max ])
157
+ while t < t_f :
158
+ # Step
159
+ solved = False
160
+ count = 0
161
+ dt = dt_max
162
+ while not solved :
163
+ # Get window of time to integrate over (so that we return
164
+ # all the points in t_eval, not just t and t+dt)
165
+ t_window = np .concatenate (
166
+ ([t ], t_eval [(t_eval > t ) & (t_eval < t + dt )])
167
+ )
168
+ # Sometimes near events the solver fails between two time
169
+ # points in t_eval (i.e. no points t < t_i < t+dt for t_i
170
+ # in t_eval), so we simply integrate from t to t+dt
171
+ if len (t_window ) == 1 :
172
+ t_window = np .array ([t , t + dt ])
173
+
174
+ integrator = self .get_integrator (model , t_window , inputs )
175
+ # Try to solve with the current global step, if it fails then
176
+ # halve the step size and try again.
177
+ try :
178
+ current_step_sol = self ._run_integrator (
179
+ integrator , model , y0 , inputs , t_window
180
+ )
181
+ solved = True
182
+ except pybamm .SolverError :
183
+ dt /= 2
184
+ # also reduce maximum step size for future global steps
185
+ dt_max = dt
186
+ count += 1
187
+ if count >= self .max_step_decrease_count :
188
+ raise pybamm .SolverError (
189
+ """
190
+ Maximum number of decreased steps occurred at t={}. Try
191
+ solving the model up to this time only or reducing dt_max.
192
+ """ .format (
193
+ t
194
+ )
195
+ )
196
+ # Check most recent y to see if any events have been crossed
197
+ new_event_signs = np .sign (
198
+ np .concatenate (
199
+ [
200
+ event (t , current_step_sol .y [:, - 1 ], inputs )
201
+ for event in model .terminate_events_eval
202
+ ]
203
+ )
204
+ )
205
+ # Exit loop if the sign of an event changes
206
+ # Locate the event time using a root finding algorithm and
207
+ # event state using interpolation. The solution is then truncated
208
+ # so that only the times up to the event are returned
209
+ if (new_event_signs != init_event_signs ).any ():
210
+ # get the index of the events that have been crossed
211
+ event_ind = np .where (new_event_signs != init_event_signs )[0 ]
212
+ active_events = [model .terminate_events_eval [i ] for i in event_ind ]
213
+
214
+ # create interpolant to evaluate y in the current integration
215
+ # window
216
+ y_sol = interp1d (current_step_sol .t , current_step_sol .y )
217
+
218
+ # loop over events to compute the time at which they were triggered
219
+ t_events = [None ] * len (active_events )
220
+ for i , event in enumerate (active_events ):
221
+
222
+ def event_fun (t ):
223
+ return event (t , y_sol (t ), inputs )
224
+
225
+ if np .isnan (event_fun (current_step_sol .t [- 1 ])[0 ]):
226
+ # bracketed search fails if f(a) or f(b) is NaN, so we
227
+ # need to find the times for which we can evaluate the event
228
+ times = [
229
+ t
230
+ for t in current_step_sol .t
231
+ if event_fun (t )[0 ] == event_fun (t )[0 ]
232
+ ]
233
+ else :
234
+ times = current_step_sol .t
235
+ # skip if sign hasn't changed
236
+ if np .sign (event_fun (times [0 ])) != np .sign (
237
+ event_fun (times [- 1 ])
238
+ ):
239
+ t_events [i ] = brentq (
240
+ lambda t : event_fun (t ), times [0 ], times [- 1 ]
241
+ )
242
+ else :
243
+ t_events [i ] = np .nan
244
+
245
+ # t_event is the earliest event triggered
246
+ t_event = np .nanmin (t_events )
247
+ y_event = y_sol (t_event )
248
+
249
+ # return truncated solution
250
+ t_truncated = current_step_sol .t [current_step_sol .t < t_event ]
251
+ y_trunctaed = current_step_sol .y [:, 0 : len (t_truncated )]
252
+ truncated_step_sol = pybamm .Solution (t_truncated , y_trunctaed )
253
+ # assign temporary solve time
254
+ truncated_step_sol .solve_time = np .nan
255
+ # append solution from the current step to solution
256
+ solution .append (truncated_step_sol )
257
+
258
+ solution .termination = "event"
259
+ solution .t_event = t_event
260
+ solution .y_event = y_event
261
+ break
262
+ else :
263
+ # assign temporary solve time
264
+ current_step_sol .solve_time = np .nan
265
+ # append solution from the current step to solution
266
+ solution .append (current_step_sol )
267
+ # update time
268
+ t = t_window [- 1 ]
269
+ # update y0
270
+ y0 = solution .y [:, - 1 ]
271
+ return solution
272
+ elif self .mode == "old safe" :
117
273
y0 = model .y0
118
274
if isinstance (y0 , casadi .DM ):
119
275
y0 = y0 .full ().flatten ()
@@ -153,7 +309,7 @@ def _integrate(self, model, t_eval, inputs=None):
153
309
raise pybamm .SolverError (
154
310
"""
155
311
Maximum number of decreased steps occurred at t={}. Try
156
- solving the model up to this time only
312
+ solving the model up to this time only.
157
313
""" .format (
158
314
t
159
315
)
@@ -182,7 +338,6 @@ def _integrate(self, model, t_eval, inputs=None):
182
338
t += dt
183
339
# update y0
184
340
y0 = solution .y [:, - 1 ]
185
-
186
341
return solution
187
342
188
343
def get_integrator (self , model , t_eval , inputs ):
@@ -192,12 +347,22 @@ def get_integrator(self, model, t_eval, inputs):
192
347
rhs = model .casadi_rhs
193
348
algebraic = model .casadi_algebraic
194
349
350
+ # When not in DEBUG mode (level=10), suppress warnings from CasADi
351
+ if (
352
+ pybamm .logger .getEffectiveLevel () == 10
353
+ or pybamm .settings .debug_mode is True
354
+ ):
355
+ show_eval_warnings = True
356
+ else :
357
+ show_eval_warnings = False
358
+
195
359
options = {
196
360
** self .extra_options_setup ,
197
361
"grid" : t_eval ,
198
362
"reltol" : self .rtol ,
199
363
"abstol" : self .atol ,
200
364
"output_t0" : True ,
365
+ "show_eval_warnings" : show_eval_warnings ,
201
366
}
202
367
203
368
# set up and solve
0 commit comments