Skip to content

Commit 6852e0e

Browse files
authoredApr 29, 2020
Merge pull request #956 from pybamm-team/issue-853-casadi-safe
Issue 853 casadi safe
2 parents cde479d + 33aff3a commit 6852e0e

File tree

9 files changed

+240
-34
lines changed

9 files changed

+240
-34
lines changed
 

‎CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010

1111
## Optimizations
1212

13+
- Changed the behaviour of "safe" mode in `CasadiSolver` ([#956](https://github.com/pybamm-team/PyBaMM/pull/956))
1314
- Sped up model building ([#927](https://github.com/pybamm-team/PyBaMM/pull/927))
1415
- Changed default solver for lead-acid to `CasadiSolver` ([#927](https://github.com/pybamm-team/PyBaMM/pull/927))
1516

1617
## Bug fixes
1718

1819
- Fixed `Interpolant` ids to allow processing ([#962](https://github.com/pybamm-team/PyBaMM/pull/962)
20+
- Fixed a bug in the initial conditions of the potential pair model ([#954](https://github.com/pybamm-team/PyBaMM/pull/954))
1921
- Changed simulation attributes to assign copies rather than the objects themselves ([#952](https://github.com/pybamm-team/PyBaMM/pull/952)
2022
- Added default values to base model so that it works with the `Simulation` class ([#952](https://github.com/pybamm-team/PyBaMM/pull/952)
2123
- Fixed solver to recompute initial conditions when inputs are changed ([#951](https://github.com/pybamm-team/PyBaMM/pull/951)

‎examples/notebooks/change-input-current.ipynb

+11-12
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
"cell_type": "markdown",
5454
"metadata": {},
5555
"source": [
56-
"We can now solve the model in the ususal way, with a 16A current"
56+
"We can now solve the model in the usual way, with a 1.6A current"
5757
]
5858
},
5959
{
@@ -66,12 +66,12 @@
6666
{
6767
"data": {
6868
"application/vnd.jupyter.widget-view+json": {
69-
"model_id": "8247adc1a3fd42dba31101bc6b501ed9",
69+
"model_id": "807f6dec221b4226aea64326d51c751b",
7070
"version_major": 2,
7171
"version_minor": 0
7272
},
7373
"text/plain": [
74-
"interactive(children=(FloatSlider(value=0.0, description='t', max=15.050167224080266, step=0.15050167224080266"
74+
"interactive(children=(FloatSlider(value=0.0, description='t', max=600.0, step=6.0), Output()), _dom_classes=('"
7575
]
7676
},
7777
"metadata": {},
@@ -92,9 +92,8 @@
9292
"\n",
9393
"# Solve the model at the given time points\n",
9494
"solver = pybamm.CasadiSolver()\n",
95-
"n = 300\n",
96-
"t_eval = np.linspace(0, 500, n)\n",
97-
"solution = solver.solve(model, t_eval, inputs={\"Current function [A]\": 16})\n",
95+
"t_eval = np.linspace(0, 600, 300)\n",
96+
"solution = solver.solve(model, t_eval, inputs={\"Current function [A]\": 1.6})\n",
9897
"\n",
9998
"# plot\n",
10099
"quick_plot = pybamm.QuickPlot(solution)\n",
@@ -116,12 +115,12 @@
116115
{
117116
"data": {
118117
"application/vnd.jupyter.widget-view+json": {
119-
"model_id": "511a6922f3f34350a543367723f09572",
118+
"model_id": "eaab57697cef4852b4fbab5ee1f7efd7",
120119
"version_major": 2,
121120
"version_minor": 0
122121
},
123122
"text/plain": [
124-
"interactive(children=(FloatSlider(value=0.0, description='t', max=500.0, step=5.0), Output()), _dom_classes=('…"
123+
"interactive(children=(FloatSlider(value=0.0, description='t', max=600.0, step=6.0), Output()), _dom_classes=('…"
125124
]
126125
},
127126
"metadata": {},
@@ -154,7 +153,7 @@
154153
{
155154
"data": {
156155
"application/vnd.jupyter.widget-view+json": {
157-
"model_id": "def542dad39b4ddebcd9a555cf684580",
156+
"model_id": "a72be467b6f84733927cb2e7a5c4b3f9",
158157
"version_major": 2,
159158
"version_minor": 0
160159
},
@@ -273,7 +272,7 @@
273272
{
274273
"data": {
275274
"application/vnd.jupyter.widget-view+json": {
276-
"model_id": "7f4df64df94d410fb9aff92003538b14",
275+
"model_id": "7dff98a2d9a44f87b71abeb0a94e16c9",
277276
"version_major": 2,
278277
"version_minor": 0
279278
},
@@ -330,9 +329,9 @@
330329
"name": "python",
331330
"nbconvert_exporter": "python",
332331
"pygments_lexer": "ipython3",
333-
"version": "3.7.3"
332+
"version": "3.6.9"
334333
}
335334
},
336335
"nbformat": 4,
337336
"nbformat_minor": 2
338-
}
337+
}

‎examples/notebooks/models/spm1.png

51.5 KB
Loading

‎examples/notebooks/models/spm2.png

1.91 KB
Loading

‎pybamm/models/submodels/current_collector/potential_pair.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,12 @@ def set_algebraic(self, variables):
6969
def set_initial_conditions(self, variables):
7070

7171
applied_current = self.param.current_with_time
72-
cc_area = self._get_effective_current_collector_area()
7372
phi_s_cn = variables["Negative current collector potential"]
7473
i_boundary_cc = variables["Current collector current density"]
7574

7675
self.initial_conditions = {
7776
phi_s_cn: pybamm.Scalar(0),
78-
i_boundary_cc: applied_current / cc_area,
77+
i_boundary_cc: applied_current,
7978
}
8079

8180

‎pybamm/solvers/base_solver.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def solve(self, model, t_eval=None, external_variables=None, inputs=None):
453453
The model whose solution to calculate. Must have attributes rhs and
454454
initial_conditions
455455
t_eval : numeric type
456-
The times at which to compute the solution
456+
The times (in seconds) at which to compute the solution
457457
external_variables : dict
458458
A dictionary of external variables and their corresponding
459459
values at the current time
@@ -491,10 +491,6 @@ def solve(self, model, t_eval=None, external_variables=None, inputs=None):
491491
# Set up external variables and inputs
492492
ext_and_inputs = self._set_up_ext_and_inputs(model, external_variables, inputs)
493493

494-
# Make sure t_eval is monotonic
495-
if (np.diff(t_eval) < 0).any():
496-
raise pybamm.SolverError("t_eval must increase monotonically")
497-
498494
# Set up
499495
timer = pybamm.Timer()
500496

@@ -511,7 +507,6 @@ def solve(self, model, t_eval=None, external_variables=None, inputs=None):
511507

512508
# Non-dimensionalise time
513509
t_eval_dimensionless = t_eval / model.timescale_eval
514-
# Solve
515510

516511
# Calculate discontinuities
517512
discontinuities = [
@@ -646,7 +641,7 @@ def step(
646641
The model whose solution to calculate. Must have attributes rhs and
647642
initial_conditions
648643
dt : numeric type
649-
The timestep over which to step the solution
644+
The timestep (in seconds) over which to step the solution
650645
npts : int, optional
651646
The number of points at which the solution will be returned during
652647
the step dt. default is 2 (returns the solution at t0 and t0 + dt).
@@ -708,10 +703,9 @@ def step(
708703

709704
# Non-dimensionalise dt
710705
dt_dimensionless = dt / model.timescale_eval
706+
711707
# Step
712708
t_eval = np.linspace(t, t + dt_dimensionless, npts)
713-
# Set inputs and external
714-
715709
pybamm.logger.info("Calling solver")
716710
timer.reset()
717711
solution = self._integrate(model, t_eval, ext_and_inputs)

‎pybamm/solvers/casadi_solver.py

+172-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import casadi
55
import pybamm
66
import numpy as np
7+
from scipy.interpolate import interp1d
8+
from scipy.optimize import brentq
79

810

911
class CasadiSolver(pybamm.BaseSolver):
@@ -22,8 +24,11 @@ class CasadiSolver(pybamm.BaseSolver):
2224
- "fast": perform direct integration, without accounting for events. \
2325
Recommended when simulating a drive cycle or other simulation where \
2426
no events should be triggered.
25-
- "safe": perform step-and-check integration, checking whether events have \
26-
been triggered. Recommended for simulations of a full charge or discharge.
27+
- "safe": perform step-and-check integration in global steps of size \
28+
dt_max, checking whether events have been triggered. Recommended for \
29+
simulations of a full charge or discharge.
30+
- "old safe": perform step-and-check integration in steps of size dt \
31+
for each dt in t_eval, checking whether events have been triggered.
2732
rtol : float, optional
2833
The relative tolerance for the solver (default is 1e-6).
2934
atol : float, optional
@@ -40,6 +45,10 @@ class CasadiSolver(pybamm.BaseSolver):
4045
max_step_decrease_counts : float, optional
4146
The maximum number of times step size can be decreased before an error is
4247
raised. Default is 5.
48+
dt_max : float, optional
49+
The maximum global step size (in seconds) used in "safe" mode. If None
50+
the default value corresponds to a non-dimensional time of 0.01
51+
(i.e. ``0.01 * model.timescale_eval``).
4352
extra_options_setup : dict, optional
4453
Any options to pass to the CasADi integrator when creating the integrator.
4554
Please consult `CasADi documentation <https://tinyurl.com/y5rk76os>`_ for
@@ -59,23 +68,27 @@ def __init__(
5968
root_method="casadi",
6069
root_tol=1e-6,
6170
max_step_decrease_count=5,
71+
dt_max=None,
6272
extra_options_setup=None,
6373
extra_options_call=None,
6474
):
6575
super().__init__("problem dependent", rtol, atol, root_method, root_tol)
66-
if mode in ["safe", "fast"]:
76+
if mode in ["safe", "fast", "old safe"]:
6777
self.mode = mode
6878
else:
6979
raise ValueError(
7080
"""
71-
invalid mode '{}'. Must be either 'safe', for solving with events,
72-
or 'fast', for solving quickly without events""".format(
81+
invalid mode '{}'. Must be either 'safe' or 'old safe', for solving
82+
with events, or 'fast', for solving quickly without events""".format(
7383
mode
7484
)
7585
)
7686
self.max_step_decrease_count = max_step_decrease_count
87+
self.dt_max = dt_max
88+
7789
self.extra_options_setup = extra_options_setup or {}
7890
self.extra_options_call = extra_options_call or {}
91+
7992
self.name = "CasADi solver with '{}' mode".format(mode)
8093

8194
# Initialize
@@ -114,6 +127,149 @@ def _integrate(self, model, t_eval, inputs=None):
114127
solution.termination = "final time"
115128
return solution
116129
elif self.mode == "safe":
130+
y0 = model.y0
131+
if isinstance(y0, casadi.DM):
132+
y0 = y0.full().flatten()
133+
# Step-and-check
134+
t = t_eval[0]
135+
t_f = t_eval[-1]
136+
init_event_signs = np.sign(
137+
np.concatenate(
138+
[event(t, y0, inputs) for event in model.terminate_events_eval]
139+
)
140+
)
141+
pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))
142+
143+
# Initialize solution
144+
solution = pybamm.Solution(np.array([t]), y0[:, np.newaxis])
145+
solution.solve_time = 0
146+
147+
# Try to integrate in global steps of size dt_max. Note: dt_max must
148+
# be at least as big as the the biggest step in t_eval (multiplied
149+
# by some tolerance, here 1.01) to avoid an empty integration window below
150+
if self.dt_max:
151+
# Non-dimensionalise provided dt_max
152+
dt_max = self.dt_max / model.timescale_eval
153+
else:
154+
dt_max = 0.01
155+
dt_eval_max = np.max(np.diff(t_eval)) * 1.01
156+
dt_max = np.max([dt_max, dt_eval_max])
157+
while t < t_f:
158+
# Step
159+
solved = False
160+
count = 0
161+
dt = dt_max
162+
while not solved:
163+
# Get window of time to integrate over (so that we return
164+
# all the points in t_eval, not just t and t+dt)
165+
t_window = np.concatenate(
166+
([t], t_eval[(t_eval > t) & (t_eval < t + dt)])
167+
)
168+
# Sometimes near events the solver fails between two time
169+
# points in t_eval (i.e. no points t < t_i < t+dt for t_i
170+
# in t_eval), so we simply integrate from t to t+dt
171+
if len(t_window) == 1:
172+
t_window = np.array([t, t + dt])
173+
174+
integrator = self.get_integrator(model, t_window, inputs)
175+
# Try to solve with the current global step, if it fails then
176+
# halve the step size and try again.
177+
try:
178+
current_step_sol = self._run_integrator(
179+
integrator, model, y0, inputs, t_window
180+
)
181+
solved = True
182+
except pybamm.SolverError:
183+
dt /= 2
184+
# also reduce maximum step size for future global steps
185+
dt_max = dt
186+
count += 1
187+
if count >= self.max_step_decrease_count:
188+
raise pybamm.SolverError(
189+
"""
190+
Maximum number of decreased steps occurred at t={}. Try
191+
solving the model up to this time only or reducing dt_max.
192+
""".format(
193+
t
194+
)
195+
)
196+
# Check most recent y to see if any events have been crossed
197+
new_event_signs = np.sign(
198+
np.concatenate(
199+
[
200+
event(t, current_step_sol.y[:, -1], inputs)
201+
for event in model.terminate_events_eval
202+
]
203+
)
204+
)
205+
# Exit loop if the sign of an event changes
206+
# Locate the event time using a root finding algorithm and
207+
# event state using interpolation. The solution is then truncated
208+
# so that only the times up to the event are returned
209+
if (new_event_signs != init_event_signs).any():
210+
# get the index of the events that have been crossed
211+
event_ind = np.where(new_event_signs != init_event_signs)[0]
212+
active_events = [model.terminate_events_eval[i] for i in event_ind]
213+
214+
# create interpolant to evaluate y in the current integration
215+
# window
216+
y_sol = interp1d(current_step_sol.t, current_step_sol.y)
217+
218+
# loop over events to compute the time at which they were triggered
219+
t_events = [None] * len(active_events)
220+
for i, event in enumerate(active_events):
221+
222+
def event_fun(t):
223+
return event(t, y_sol(t), inputs)
224+
225+
if np.isnan(event_fun(current_step_sol.t[-1])[0]):
226+
# bracketed search fails if f(a) or f(b) is NaN, so we
227+
# need to find the times for which we can evaluate the event
228+
times = [
229+
t
230+
for t in current_step_sol.t
231+
if event_fun(t)[0] == event_fun(t)[0]
232+
]
233+
else:
234+
times = current_step_sol.t
235+
# skip if sign hasn't changed
236+
if np.sign(event_fun(times[0])) != np.sign(
237+
event_fun(times[-1])
238+
):
239+
t_events[i] = brentq(
240+
lambda t: event_fun(t), times[0], times[-1]
241+
)
242+
else:
243+
t_events[i] = np.nan
244+
245+
# t_event is the earliest event triggered
246+
t_event = np.nanmin(t_events)
247+
y_event = y_sol(t_event)
248+
249+
# return truncated solution
250+
t_truncated = current_step_sol.t[current_step_sol.t < t_event]
251+
y_trunctaed = current_step_sol.y[:, 0 : len(t_truncated)]
252+
truncated_step_sol = pybamm.Solution(t_truncated, y_trunctaed)
253+
# assign temporary solve time
254+
truncated_step_sol.solve_time = np.nan
255+
# append solution from the current step to solution
256+
solution.append(truncated_step_sol)
257+
258+
solution.termination = "event"
259+
solution.t_event = t_event
260+
solution.y_event = y_event
261+
break
262+
else:
263+
# assign temporary solve time
264+
current_step_sol.solve_time = np.nan
265+
# append solution from the current step to solution
266+
solution.append(current_step_sol)
267+
# update time
268+
t = t_window[-1]
269+
# update y0
270+
y0 = solution.y[:, -1]
271+
return solution
272+
elif self.mode == "old safe":
117273
y0 = model.y0
118274
if isinstance(y0, casadi.DM):
119275
y0 = y0.full().flatten()
@@ -153,7 +309,7 @@ def _integrate(self, model, t_eval, inputs=None):
153309
raise pybamm.SolverError(
154310
"""
155311
Maximum number of decreased steps occurred at t={}. Try
156-
solving the model up to this time only
312+
solving the model up to this time only.
157313
""".format(
158314
t
159315
)
@@ -182,7 +338,6 @@ def _integrate(self, model, t_eval, inputs=None):
182338
t += dt
183339
# update y0
184340
y0 = solution.y[:, -1]
185-
186341
return solution
187342

188343
def get_integrator(self, model, t_eval, inputs):
@@ -192,12 +347,22 @@ def get_integrator(self, model, t_eval, inputs):
192347
rhs = model.casadi_rhs
193348
algebraic = model.casadi_algebraic
194349

350+
# When not in DEBUG mode (level=10), suppress warnings from CasADi
351+
if (
352+
pybamm.logger.getEffectiveLevel() == 10
353+
or pybamm.settings.debug_mode is True
354+
):
355+
show_eval_warnings = True
356+
else:
357+
show_eval_warnings = False
358+
195359
options = {
196360
**self.extra_options_setup,
197361
"grid": t_eval,
198362
"reltol": self.rtol,
199363
"abstol": self.atol,
200364
"output_t0": True,
365+
"show_eval_warnings": show_eval_warnings,
201366
}
202367

203368
# set up and solve

‎tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_compare_outputs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_compare_outputs_thermal(self):
8585
solutions = []
8686
t_eval = np.linspace(0, 3600, 100)
8787
for model in models:
88-
solution = pybamm.CasadiSolver().solve(model, t_eval)
88+
solution = pybamm.CasadiSolver(dt_max=0.01).solve(model, t_eval)
8989
solutions.append(solution)
9090

9191
# compare outputs

‎tests/unit/test_solvers/test_casadi_solver.py

+50-3
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,15 @@ def test_model_solver_failure(self):
8282
disc.process_model(model)
8383

8484
solver = pybamm.CasadiSolver(extra_options_call={"regularity_check": False})
85-
85+
solver_old = pybamm.CasadiSolver(
86+
mode="old safe", extra_options_call={"regularity_check": False}
87+
)
8688
# Solve with failure at t=2
8789
t_eval = np.linspace(0, 20, 100)
8890
with self.assertRaises(pybamm.SolverError):
8991
solver.solve(model, t_eval)
92+
with self.assertRaises(pybamm.SolverError):
93+
solver_old.solve(model, t_eval)
9094
# Solve with failure at t=0
9195
model.initial_conditions = {var: 0}
9296
disc.process_model(model)
@@ -110,8 +114,36 @@ def test_model_solver_events(self):
110114
disc = get_discretisation_for_testing()
111115
disc.process_model(model)
112116

113-
# Solve
114-
solver = pybamm.CasadiSolver(rtol=1e-8, atol=1e-8)
117+
# Solve using "safe" mode
118+
solver = pybamm.CasadiSolver(mode="safe", rtol=1e-8, atol=1e-8)
119+
t_eval = np.linspace(0, 5, 100)
120+
solution = solver.solve(model, t_eval)
121+
np.testing.assert_array_less(solution.y[0], 1.5)
122+
np.testing.assert_array_less(solution.y[-1], 2.5)
123+
np.testing.assert_array_almost_equal(
124+
solution.y[0], np.exp(0.1 * solution.t), decimal=5
125+
)
126+
np.testing.assert_array_almost_equal(
127+
solution.y[-1], 2 * np.exp(0.1 * solution.t), decimal=5
128+
)
129+
130+
# Solve using "safe" mode with debug off
131+
pybamm.settings.debug_mode = False
132+
solver = pybamm.CasadiSolver(mode="safe", rtol=1e-8, atol=1e-8, dt_max=1)
133+
t_eval = np.linspace(0, 5, 100)
134+
solution = solver.solve(model, t_eval)
135+
np.testing.assert_array_less(solution.y[0], 1.5)
136+
np.testing.assert_array_less(solution.y[-1], 2.5)
137+
np.testing.assert_array_almost_equal(
138+
solution.y[0], np.exp(0.1 * solution.t), decimal=5
139+
)
140+
np.testing.assert_array_almost_equal(
141+
solution.y[-1], 2 * np.exp(0.1 * solution.t), decimal=5
142+
)
143+
pybamm.settings.debug_mode = True
144+
145+
# Solve using "old safe" mode
146+
solver = pybamm.CasadiSolver(mode="old safe", rtol=1e-8, atol=1e-8)
115147
t_eval = np.linspace(0, 5, 100)
116148
solution = solver.solve(model, t_eval)
117149
np.testing.assert_array_less(solution.y[0], 1.5)
@@ -123,6 +155,21 @@ def test_model_solver_events(self):
123155
solution.y[-1], 2 * np.exp(0.1 * solution.t), decimal=5
124156
)
125157

158+
# Test when an event returns nan
159+
model = pybamm.BaseModel()
160+
var = pybamm.Variable("var")
161+
model.rhs = {var: 0.1 * var}
162+
model.initial_conditions = {var: 1}
163+
model.events = [
164+
pybamm.Event("event", var - 1.02),
165+
pybamm.Event("sqrt event", pybamm.sqrt(1.0199 - var)),
166+
]
167+
disc = pybamm.Discretisation()
168+
disc.process_model(model)
169+
solver = pybamm.CasadiSolver(rtol=1e-8, atol=1e-8)
170+
solution = solver.solve(model, t_eval)
171+
np.testing.assert_array_less(solution.y[0], 1.02)
172+
126173
def test_model_step(self):
127174
# Create model
128175
model = pybamm.BaseModel()

0 commit comments

Comments
 (0)
Please sign in to comment.