Skip to content

Commit faa3167

Browse files
Merge pull request #1315 from brosaplanella/issue-976-casadi-extrapolate-warning
Issue 976 casadi extrapolate warning
2 parents 0926080 + 4a64cc6 commit faa3167

13 files changed

+268
-13
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
## Bug fixes
2020

21+
- When an `Interpolant` is extrapolated an error is raised for `CasadiSolver` (and a warning is raised for the other solvers) ([#1315](https://github.com/pybamm-team/PyBaMM/pull/1315))
2122
- Fixed `Simulation` and `model.new_copy` to fix a bug where changes to the model were overwritten ([#1278](https://github.com/pybamm-team/PyBaMM/pull/1278))
2223

2324
## Breaking changes

pybamm/models/event.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class EventType(Enum):
1616

1717
TERMINATION = 0
1818
DISCONTINUITY = 1
19+
INTERPOLANT_EXTRAPOLATION = 2
1920

2021

2122
class Event:

pybamm/parameters/parameter_values.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(self, values=None, chemistry=None):
8383

8484
# Initialise empty _processed_symbols dict (for caching)
8585
self._processed_symbols = {}
86+
self.parameter_events = []
8687

8788
def __getitem__(self, key):
8889
return self._dict_items[key]
@@ -403,13 +404,24 @@ def process_model(self, unprocessed_model, inplace=True):
403404
new_events = []
404405
for event in unprocessed_model.events:
405406
pybamm.logger.debug(
406-
"Processing parameters for event'{}''".format(event.name)
407+
"Processing parameters for event '{}''".format(event.name)
407408
)
408409
new_events.append(
409410
pybamm.Event(
410411
event.name, self.process_symbol(event.expression), event.event_type
411412
)
412413
)
414+
415+
for event in self.parameter_events:
416+
pybamm.logger.debug(
417+
"Processing parameters for event '{}''".format(event.name)
418+
)
419+
new_events.append(
420+
pybamm.Event(
421+
event.name, self.process_symbol(event.expression), event.event_type
422+
)
423+
)
424+
413425
model.events = new_events
414426

415427
# Set external variables
@@ -547,6 +559,23 @@ def _process_symbol(self, symbol):
547559
function = pybamm.Interpolant(
548560
data[:, 0], data[:, 1], *new_children, name=name
549561
)
562+
# Define event to catch extrapolation. In these events the sign is
563+
# important: it should be positive inside of the range and negative
564+
# outside of it
565+
self.parameter_events.append(
566+
pybamm.Event(
567+
"Interpolant {} lower bound".format(name),
568+
new_children[0] - min(data[:, 0]),
569+
pybamm.EventType.INTERPOLANT_EXTRAPOLATION,
570+
)
571+
)
572+
self.parameter_events.append(
573+
pybamm.Event(
574+
"Interpolant {} upper bound".format(name),
575+
max(data[:, 0]) - new_children[0],
576+
pybamm.EventType.INTERPOLANT_EXTRAPOLATION,
577+
)
578+
)
550579
elif isinstance(function_name, numbers.Number):
551580
# If the "function" is provided is actually a scalar, return a Scalar
552581
# object instead of throwing an error.

pybamm/solvers/base_solver.py

+76
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import sys
1010
import itertools
11+
import warnings
1112

1213

1314
class BaseSolver(object):
@@ -30,6 +31,8 @@ class BaseSolver(object):
3031
specified by 'root_method' (e.g. "lm", "hybr", ...)
3132
root_tol : float, optional
3233
The tolerance for the initial-condition solver (default is 1e-6).
34+
extrap_tol : float, optional
35+
The tolerance to assert whether extrapolation occurs or not. Default is 0.
3336
"""
3437

3538
def __init__(
@@ -39,13 +42,15 @@ def __init__(
3942
atol=1e-6,
4043
root_method=None,
4144
root_tol=1e-6,
45+
extrap_tol=0,
4246
max_steps="deprecated",
4347
):
4448
self._method = method
4549
self._rtol = rtol
4650
self._atol = atol
4751
self.root_tol = root_tol
4852
self.root_method = root_method
53+
self.extrap_tol = extrap_tol
4954
if max_steps != "deprecated":
5055
raise ValueError(
5156
"max_steps has been deprecated, and should be set using the "
@@ -361,6 +366,12 @@ def report(string):
361366
if event.event_type == pybamm.EventType.TERMINATION
362367
]
363368

369+
interpolant_extrapolation_events_eval = [
370+
process(event.expression, "event", use_jacobian=False)[1]
371+
for event in model.events
372+
if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION
373+
]
374+
364375
# discontinuity events are evaluated before the solver is called, so don't need
365376
# to process them
366377
discontinuity_events_eval = [
@@ -376,6 +387,9 @@ def report(string):
376387
model.jac_algebraic_eval = jac_algebraic
377388
model.terminate_events_eval = terminate_events_eval
378389
model.discontinuity_events_eval = discontinuity_events_eval
390+
model.interpolant_extrapolation_events_eval = (
391+
interpolant_extrapolation_events_eval
392+
)
379393

380394
# Calculate initial conditions
381395
model.y0 = init_eval(inputs)
@@ -697,6 +711,16 @@ def solve(
697711
solution.timescale_eval = model.timescale_eval
698712
solution.length_scales_eval = model.length_scales_eval
699713

714+
# Check if extrapolation occurred
715+
extrapolation = self.check_extrapolation(solution, model.events)
716+
if extrapolation:
717+
warnings.warn(
718+
"While solving {} extrapolation occurred for {}".format(
719+
model.name, extrapolation
720+
),
721+
pybamm.SolverWarning,
722+
)
723+
700724
# Identify the event that caused termination
701725
termination = self.get_termination_reason(solution, model.events)
702726

@@ -852,6 +876,16 @@ def step(
852876
solution.timescale_eval = temp_timescale_eval
853877
solution.length_scales_eval = temp_length_scales_eval
854878

879+
# Check if extrapolation occurred
880+
extrapolation = self.check_extrapolation(solution, model.events)
881+
if extrapolation:
882+
warnings.warn(
883+
"While solving {} extrapolation occurred for {}".format(
884+
model.name, extrapolation
885+
),
886+
pybamm.SolverWarning,
887+
)
888+
855889
# Identify the event that caused termination
856890
termination = self.get_termination_reason(solution, model.events)
857891

@@ -921,6 +955,48 @@ def get_termination_reason(self, solution, events):
921955

922956
return "the termination event '{}' occurred".format(termination_event)
923957

958+
def check_extrapolation(self, solution, events):
959+
"""
960+
Check if extrapolation occurred for any of the interpolants. Note that with the
961+
current approach (evaluating all the events at the solution times) some
962+
extrapolations might not be found if they only occurred for a small period of
963+
time.
964+
965+
Parameters
966+
----------
967+
solution : :class:`pybamm.Solution`
968+
The solution object
969+
events : dict
970+
Dictionary of events
971+
"""
972+
extrap_events = {}
973+
974+
for event in events:
975+
if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION:
976+
extrap_events[event.name] = False
977+
978+
try:
979+
y_full = solution.y.full()
980+
except AttributeError:
981+
y_full = solution.y
982+
983+
for event in events:
984+
if event.event_type == pybamm.EventType.INTERPOLANT_EXTRAPOLATION:
985+
if (
986+
event.expression.evaluate(
987+
solution.t,
988+
y_full,
989+
inputs={k: v for k, v in solution.inputs.items()},
990+
)
991+
< self.extrap_tol
992+
).any():
993+
extrap_events[event.name] = True
994+
995+
# Add the event dictionaryto the solution object
996+
solution.extrap_events = extrap_events
997+
998+
return [k for k, v in extrap_events.items() if v]
999+
9241000
def _set_up_ext_and_inputs(self, model, external_variables, inputs):
9251001
"Set up external variables and input parameters"
9261002
inputs = inputs or {}

pybamm/solvers/casadi_solver.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class CasadiSolver(pybamm.BaseSolver):
4747
The maximum global step size (in seconds) used in "safe" mode. If None
4848
the default value corresponds to a non-dimensional time of 0.01
4949
(i.e. ``0.01 * model.timescale_eval``).
50+
extrap_tol : float, optional
51+
The tolerance to assert whether extrapolation occurs or not. Default is 0.
5052
extra_options_setup : dict, optional
5153
Any options to pass to the CasADi integrator when creating the integrator.
5254
Please consult `CasADi documentation <https://tinyurl.com/y5rk76os>`_ for
@@ -71,10 +73,13 @@ def __init__(
7173
root_tol=1e-6,
7274
max_step_decrease_count=5,
7375
dt_max=None,
76+
extrap_tol=0,
7477
extra_options_setup=None,
7578
extra_options_call=None,
7679
):
77-
super().__init__("problem dependent", rtol, atol, root_method, root_tol)
80+
super().__init__(
81+
"problem dependent", rtol, atol, root_method, root_tol, extrap_tol
82+
)
7883
if mode in ["safe", "fast", "safe without grid"]:
7984
self.mode = mode
8085
else:
@@ -88,6 +93,7 @@ def __init__(
8893

8994
self.extra_options_setup = extra_options_setup or {}
9095
self.extra_options_call = extra_options_call or {}
96+
self.extrap_tol = extrap_tol
9197

9298
self.name = "CasADi solver with '{}' mode".format(mode)
9399

@@ -141,6 +147,33 @@ def _integrate(self, model, t_eval, inputs=None):
141147
[event(t, y0, inputs) for event in model.terminate_events_eval]
142148
)
143149
)
150+
151+
extrap_event = [
152+
event(t, y0, inputs)
153+
for event in model.interpolant_extrapolation_events_eval
154+
]
155+
156+
if extrap_event:
157+
if (np.concatenate(extrap_event) < self.extrap_tol).any():
158+
extrap_event_names = []
159+
for event in model.events:
160+
if (
161+
event.event_type
162+
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
163+
and (
164+
event.expression.evaluate(t, y0.full(), inputs=inputs,)
165+
< self.extrap_tol
166+
).any()
167+
):
168+
extrap_event_names.append(event.name[12:])
169+
170+
raise pybamm.SolverError(
171+
"CasADI solver failed because the following interpolation "
172+
"bounds were exceeded at the initial conditions: {}. "
173+
"You may need to provide additional interpolation points "
174+
"outside these bounds.".format(extrap_event_names)
175+
)
176+
144177
pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))
145178

146179
if self.mode == "safe without grid":
@@ -215,6 +248,37 @@ def _integrate(self, model, t_eval, inputs=None):
215248
]
216249
)
217250
)
251+
252+
extrap_event = [
253+
event(t, current_step_sol.y[:, -1], inputs=inputs)
254+
for event in model.interpolant_extrapolation_events_eval
255+
]
256+
257+
if extrap_event:
258+
if (np.concatenate(extrap_event) < self.extrap_tol).any():
259+
extrap_event_names = []
260+
for event in model.events:
261+
if (
262+
event.event_type
263+
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
264+
and (
265+
event.expression.evaluate(
266+
t,
267+
current_step_sol.y[:, -1].full(),
268+
inputs=inputs,
269+
)
270+
< self.extrap_tol
271+
).any()
272+
):
273+
extrap_event_names.append(event.name[12:])
274+
275+
raise pybamm.SolverError(
276+
"CasADI solver failed because the following interpolation "
277+
"bounds were exceeded: {}. You may need to provide "
278+
"additional interpolation points outside these "
279+
"bounds.".format(extrap_event_names)
280+
)
281+
218282
# Exit loop if the sign of an event changes
219283
# Locate the event time using a root finding algorithm and
220284
# event state using interpolation. The solution is then truncated

pybamm/solvers/idaklu_solver.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class IDAKLUSolver(pybamm.BaseSolver):
3636
specified by 'root_method' (e.g. "lm", "hybr", ...)
3737
root_tol : float, optional
3838
The tolerance for the initial-condition solver (default is 1e-6).
39+
extrap_tol : float, optional
40+
The tolerance to assert whether extrapolation occurs or not (default is 0).
3941
"""
4042

4143
def __init__(
@@ -44,13 +46,16 @@ def __init__(
4446
atol=1e-6,
4547
root_method="casadi",
4648
root_tol=1e-6,
49+
extrap_tol=0,
4750
max_steps="deprecated",
4851
):
4952

5053
if idaklu_spec is None:
5154
raise ImportError("KLU is not installed")
5255

53-
super().__init__("ida", rtol, atol, root_method, root_tol, max_steps)
56+
super().__init__(
57+
"ida", rtol, atol, root_method, root_tol, extrap_tol, max_steps
58+
)
5459
self.name = "IDA KLU solver"
5560

5661
pybamm.citations.register("hindmarsh2000pvode")

pybamm/solvers/jax_bdf_solver.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -715,10 +715,7 @@ def block_fun(i, j, Ai, Aj):
715715
return Ai
716716
else:
717717
return onp.zeros(
718-
(
719-
Ai.shape[0] if Ai.ndim > 1 else 1,
720-
Aj.shape[1] if Aj.ndim > 1 else 1,
721-
),
718+
(Ai.shape[0] if Ai.ndim > 1 else 1, Aj.shape[1] if Aj.ndim > 1 else 1,),
722719
dtype=Ai.dtype,
723720
)
724721

pybamm/solvers/jax_solver.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class JaxSolver(pybamm.BaseSolver):
3838
The relative tolerance for the solver (default is 1e-6).
3939
atol : float, optional
4040
The absolute tolerance for the solver (default is 1e-6).
41+
extrap_tol : float, optional
42+
The tolerance to assert whether extrapolation occurs or not (default is 0).
4143
extra_options : dict, optional
4244
Any options to pass to the solver.
4345
Please consult `JAX documentation
@@ -46,11 +48,19 @@ class JaxSolver(pybamm.BaseSolver):
4648
"""
4749

4850
def __init__(
49-
self, method="RK45", root_method=None, rtol=1e-6, atol=1e-6, extra_options=None
51+
self,
52+
method="RK45",
53+
root_method=None,
54+
rtol=1e-6,
55+
atol=1e-6,
56+
extrap_tol=0,
57+
extra_options=None,
5058
):
5159
# note: bdf solver itself calculates consistent initial conditions so can set
5260
# root_method to none, allow user to override this behavior
53-
super().__init__(method, rtol, atol, root_method=root_method)
61+
super().__init__(
62+
method, rtol, atol, root_method=root_method, extrap_tol=extrap_tol
63+
)
5464
method_options = ["RK45", "BDF"]
5565
if method not in method_options:
5666
raise ValueError("method must be one of {}".format(method_options))

0 commit comments

Comments
 (0)