Skip to content

Commit 306b98e

Browse files
committedFeb 7, 2020
#759 solver checks for heaviside functions and adds appropriate discontinuity events
1 parent 09e1694 commit 306b98e

File tree

2 files changed

+80
-43
lines changed

2 files changed

+80
-43
lines changed
 

‎pybamm/solvers/base_solver.py

+13
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,19 @@ def report(string):
212212
jac_call = None
213213
return func, func_call, jac_call
214214

215+
# Check for heaviside functions in rhs and algebraic and add discontinuity
216+
# events if these exist.
217+
# Note: only checks for the case of t < X, t <= X, X < t, or X <= t
218+
for symbol in model.concatenated_rhs.pre_order():
219+
if isinstance(symbol, pybamm.Heaviside):
220+
if symbol.right.id == pybamm.t.id:
221+
expr = symbol.left
222+
elif symbol.left.id == pybamm.t.id:
223+
expr = symbol.right
224+
225+
model.events.append(pybamm.Event(str(symbol), expr.new_copy(),
226+
pybamm.EventType.DISCONTINUITY))
227+
215228
# Process rhs, algebraic and event expressions
216229
rhs, rhs_eval, jac_rhs = process(model.concatenated_rhs, "RHS")
217230
algebraic, algebraic_eval, jac_algebraic = process(

‎tests/unit/test_solvers/test_scipy_solver.py

+67-43
Original file line numberDiff line numberDiff line change
@@ -139,65 +139,89 @@ def jacobian(t, y):
139139
)
140140

141141
def test_model_solver_ode_nonsmooth(self):
142-
model = pybamm.BaseModel()
143142
whole_cell = ["negative electrode", "separator", "positive electrode"]
144143
var1 = pybamm.Variable("var1", domain=whole_cell)
145144
discontinuity = 0.6
146145

146+
# Create three different models with the same solution, each expressing the
147+
# discontinuity in a different way
148+
149+
# first model explicitly adds a discontinuity event
147150
def nonsmooth_rate(t):
148151
return 0.1 * (t < discontinuity) + 0.1
149152

150153
rate = pybamm.Function(nonsmooth_rate, pybamm.t)
151-
model.rhs = {var1: rate * var1}
152-
model.initial_conditions = {var1: 1}
153-
model.events = [
154+
model1 = pybamm.BaseModel()
155+
model1.rhs = {var1: rate * var1}
156+
model1.initial_conditions = {var1: 1}
157+
model1.events = [
154158
pybamm.Event("var1 = 1.5", pybamm.min(var1 - 1.5)),
155159
pybamm.Event("nonsmooth rate",
156160
pybamm.Scalar(discontinuity),
157161
pybamm.EventType.DISCONTINUITY
158162
),
159163
]
160-
disc = get_discretisation_for_testing()
161-
disc.process_model(model)
162164

163-
# Solve
164-
solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8)
165+
# second model implicitly adds a discontinuity event via a heaviside function
166+
model2 = pybamm.BaseModel()
167+
model2.rhs = {var1: (0.1 * (pybamm.t < discontinuity) + 0.1) * var1}
168+
model2.initial_conditions = {var1: 1}
169+
model2.events = [
170+
pybamm.Event("var1 = 1.5", pybamm.min(var1 - 1.5)),
171+
]
165172

166-
# create two time series, one without a time point on the discontinuity,
167-
# and one with
168-
t_eval1 = np.linspace(0, 5, 10)
169-
t_eval2 = np.insert(t_eval1,
170-
np.searchsorted(t_eval1, discontinuity),
171-
discontinuity)
172-
solution1 = solver.solve(model, t_eval1)
173-
solution2 = solver.solve(model, t_eval2)
174-
175-
# check time vectors
176-
for solution in [solution1, solution2]:
177-
# time vectors are ordered
178-
self.assertTrue(np.all(solution.t[:-1] <= solution.t[1:]))
179-
180-
# time value before and after discontinuity is an epsilon away
181-
dindex = np.searchsorted(solution.t, discontinuity)
182-
value_before = solution.t[dindex - 1]
183-
value_after = solution.t[dindex]
184-
self.assertEqual(value_before + sys.float_info.epsilon, discontinuity)
185-
self.assertEqual(value_after - sys.float_info.epsilon, discontinuity)
186-
187-
# both solution time vectors should have same number of points
188-
self.assertEqual(len(solution1.t), len(solution2.t))
189-
190-
# check solution
191-
for solution in [solution1, solution2]:
192-
np.testing.assert_array_less(solution.y[0], 1.5)
193-
np.testing.assert_array_less(solution.y[-1], 2.5)
194-
var1_soln = np.exp(0.2 * solution.t)
195-
y0 = np.exp(0.2 * discontinuity)
196-
var1_soln[solution.t > discontinuity] = \
197-
y0 * np.exp(
198-
0.1 * (solution.t[solution.t > discontinuity] - discontinuity)
199-
)
200-
np.testing.assert_allclose(solution.y[0], var1_soln, rtol=1e-06)
173+
# third model implicitly adds a discontinuity event via another heaviside
174+
# function
175+
model3 = pybamm.BaseModel()
176+
model3.rhs = {var1: (-0.1 * (discontinuity < pybamm.t) + 0.2) * var1}
177+
model3.initial_conditions = {var1: 1}
178+
model3.events = [
179+
pybamm.Event("var1 = 1.5", pybamm.min(var1 - 1.5)),
180+
]
181+
182+
for model in [model1, model2, model3]:
183+
184+
disc = get_discretisation_for_testing()
185+
disc.process_model(model)
186+
187+
# Solve
188+
solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8)
189+
190+
# create two time series, one without a time point on the discontinuity,
191+
# and one with
192+
t_eval1 = np.linspace(0, 5, 10)
193+
t_eval2 = np.insert(t_eval1,
194+
np.searchsorted(t_eval1, discontinuity),
195+
discontinuity)
196+
solution1 = solver.solve(model, t_eval1)
197+
solution2 = solver.solve(model, t_eval2)
198+
199+
# check time vectors
200+
for solution in [solution1, solution2]:
201+
# time vectors are ordered
202+
self.assertTrue(np.all(solution.t[:-1] <= solution.t[1:]))
203+
204+
# time value before and after discontinuity is an epsilon away
205+
dindex = np.searchsorted(solution.t, discontinuity)
206+
value_before = solution.t[dindex - 1]
207+
value_after = solution.t[dindex]
208+
self.assertEqual(value_before + sys.float_info.epsilon, discontinuity)
209+
self.assertEqual(value_after - sys.float_info.epsilon, discontinuity)
210+
211+
# both solution time vectors should have same number of points
212+
self.assertEqual(len(solution1.t), len(solution2.t))
213+
214+
# check solution
215+
for solution in [solution1, solution2]:
216+
np.testing.assert_array_less(solution.y[0], 1.5)
217+
np.testing.assert_array_less(solution.y[-1], 2.5)
218+
var1_soln = np.exp(0.2 * solution.t)
219+
y0 = np.exp(0.2 * discontinuity)
220+
var1_soln[solution.t > discontinuity] = \
221+
y0 * np.exp(
222+
0.1 * (solution.t[solution.t > discontinuity] - discontinuity)
223+
)
224+
np.testing.assert_allclose(solution.y[0], var1_soln, rtol=1e-06)
201225

202226
def test_model_step_python(self):
203227
# Create model

0 commit comments

Comments
 (0)