Skip to content

Commit 50ca0fc

Browse files
committed
#759 discontinuity events seem to be working now
1 parent 1b42969 commit 50ca0fc

9 files changed

+224
-57
lines changed

pybamm/discretisations/discretisation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def process_model(self, model, inplace=True, check_model=True):
176176
processed_events = []
177177
pybamm.logger.info("Discretise events for {}".format(model.name))
178178
for event in model.events:
179-
pybamm.logger.debug("Discretise event '{}'".format(event))
179+
pybamm.logger.debug("Discretise event '{}'".format(event.name))
180180
processed_event = pybamm.Event(
181181
event.name,
182182
self.process_symbol(event.expression),

pybamm/solvers/base_solver.py

+84-17
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
from scipy import optimize
99
from scipy.sparse import issparse
10+
import sys
1011

1112

1213
class BaseSolver(object):
@@ -218,13 +219,15 @@ def report(string):
218219
)
219220
terminate_events_eval = [
220221
process(event.expression, "event", use_jacobian=False)[1]
221-
for event in model.events
222-
if events.type == pybamm.EventType.TERMINATION
222+
for event in model.events
223+
if event.event_type == pybamm.EventType.TERMINATION
223224
]
225+
226+
# discontinuity events are evaluated before the solver is called, so don't need
227+
# to process them
224228
discontinuity_events_eval = [
225-
process(event.expression, "event", use_jacobian=False)[1]
226-
for event in model.events
227-
if events.type == pybamm.EventType.DISCONTINUITY
229+
event for event in model.events
230+
if event.event_type == pybamm.EventType.DISCONTINUITY
228231
]
229232

230233
# Add the solver attributes
@@ -243,7 +246,8 @@ def report(string):
243246
residuals, residuals_eval, jacobian_eval = process(all_states, "residuals")
244247
model.residuals_eval = residuals_eval
245248
model.jacobian_eval = jacobian_eval
246-
model.y0 = self.calculate_consistent_initial_conditions(model)
249+
y0_guess = model.concatenated_initial_conditions.flatten()
250+
model.y0 = self.calculate_consistent_state(model, 0, y0_guess)
247251
else:
248252
# can use DAE solver to solve ODE model
249253
model.residuals_eval = Residuals(rhs, "residuals", model)
@@ -281,14 +285,12 @@ def set_inputs(self, model, ext_and_inputs):
281285
model.residuals_eval.set_inputs(ext_and_inputs)
282286
for evnt in model.terminate_events_eval:
283287
evnt.set_inputs(ext_and_inputs)
284-
for evnt in model.discontinuity_events_eval:
285-
evnt.set_inputs(ext_and_inputs)
286288
if model.jacobian_eval:
287289
model.jacobian_eval.set_inputs(ext_and_inputs)
288290

289-
def calculate_consistent_initial_conditions(self, model):
291+
def calculate_consistent_state(self, model, time=0, y0_guess=None):
290292
"""
291-
Calculate consistent initial conditions for the algebraic equations through
293+
Calculate consistent state for the algebraic equations through
292294
root-finding
293295
294296
Parameters
@@ -305,8 +307,9 @@ def calculate_consistent_initial_conditions(self, model):
305307
pybamm.logger.info("Start calculating consistent initial conditions")
306308
rhs = model.rhs_eval
307309
algebraic = model.algebraic_eval
308-
y0_guess = model.concatenated_initial_conditions.flatten()
309310
jac = model.jac_algebraic_eval
311+
if y0_guess is None:
312+
y0_guess = model.concatenated_initial_conditions.flatten()
310313

311314
# Split y0_guess into differential and algebraic
312315
len_rhs = rhs(0, y0_guess).shape[0]
@@ -315,7 +318,7 @@ def calculate_consistent_initial_conditions(self, model):
315318
def root_fun(y0_alg):
316319
"Evaluates algebraic using y0_diff (fixed) and y0_alg (changed by algo)"
317320
y0 = np.concatenate([y0_diff, y0_alg])
318-
out = algebraic(0, y0)
321+
out = algebraic(time, y0)
319322
pybamm.logger.debug(
320323
"Evaluating algebraic equations at t=0, L2-norm is {}".format(
321324
np.linalg.norm(out)
@@ -421,13 +424,77 @@ def solve(self, model, t_eval, external_variables=None, inputs=None):
421424
# Set inputs and external
422425
self.set_inputs(model, ext_and_inputs)
423426

424-
timer.reset()
425-
pybamm.logger.info("Calling solver")
426-
solution = self._integrate(model, t_eval, ext_and_inputs)
427+
# Calculate discontinuities
428+
discontinuities = [
429+
event.expression.evaluate(u=inputs) for event in model.discontinuity_events_eval
430+
]
431+
432+
# make sure they are increasing in time
433+
discontinuities = sorted(discontinuities)
434+
pybamm.logger.info(
435+
'Discontinuity events found at t = {}'.format(discontinuities)
436+
)
437+
# remove any identical discontinuities
438+
discontinuities = [
439+
v for i, v in enumerate(discontinuities)
440+
if i==len(discontinuities)-1 or discontinuities[i] < discontinuities[i+1]
441+
]
442+
443+
# insert time points around discontinuities in t_eval
444+
# keep track of sub sections to integrate by storing start and end indices
445+
start_indices = [0]
446+
end_indices = []
447+
for dtime in discontinuities:
448+
dindex = np.searchsorted(t_eval, dtime, side='left')
449+
end_indices.append(dindex+1)
450+
start_indices.append(dindex+1)
451+
if t_eval[dindex] == dtime:
452+
t_eval[dindex] += sys.float_info.epsilon
453+
t_eval = np.insert(t_eval, dindex, dtime - sys.float_info.epsilon)
454+
else:
455+
t_eval = np.insert(t_eval, dindex,
456+
[dtime - sys.float_info.epsilon, dtime + sys.float_info.epsilon])
457+
end_indices.append(len(t_eval))
458+
459+
old_y0 = model.y0
460+
solution = None
461+
for start_index, end_index in zip(start_indices, end_indices):
462+
pybamm.logger.info("Calling solver for {} < t < {}"
463+
.format(t_eval[start_index], t_eval[end_index-1]))
464+
timer.reset()
465+
if solution is None:
466+
solution = self._integrate(
467+
model, t_eval[start_index:end_index], ext_and_inputs)
468+
solution.solve_time = timer.time()
469+
else:
470+
new_solution = self._integrate(
471+
model, t_eval[start_index:end_index], ext_and_inputs)
472+
new_solution.solve_time = timer.time()
473+
solution.append(new_solution, start_index=0)
474+
475+
if solution.termination != "final time":
476+
break
477+
478+
if end_index != len(t_eval):
479+
# setup for next integration subsection
480+
y0_guess = solution.y[:, -1]
481+
if model.algebraic:
482+
model.y0 = self.calculate_consistent_state(model, t_eval[end_index], y0_guess)
483+
else:
484+
model.y0 = y0_guess
485+
486+
last_state = solution.y[:, -1]
487+
if len(model.algebraic) > 0:
488+
model.y0 = self.calculate_consistent_state(
489+
model, t_eval[end_index], last_state)
490+
else:
491+
model.y0 = last_state
492+
493+
# restore old y0
494+
model.y0 = old_y0
427495

428496
# Assign times
429497
solution.set_up_time = set_up_time
430-
solution.solve_time = timer.time()
431498

432499
# Add model and inputs to solution
433500
solution.model = model
@@ -571,7 +638,7 @@ def get_termination_reason(self, solution, events):
571638
final_event_values = {}
572639

573640
for event in events:
574-
if event.type == pybamm.EventType.TERMINATION:
641+
if event.event_type == pybamm.EventType.TERMINATION:
575642
final_event_values[event.name] = abs(
576643
event.expression.evaluate(
577644
solution.t_event,

pybamm/solvers/scikits_dae_solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _integrate(self, model, t_eval, inputs=None):
6666
"""
6767
residuals = model.residuals_eval
6868
y0 = model.y0
69-
events = model.events_eval
69+
events = model.terminate_events_eval
7070
jacobian = model.jacobian_eval
7171
mass_matrix = model.mass_matrix.entries
7272

pybamm/solvers/scikits_ode_solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _integrate(self, model, t_eval, inputs=None):
5959
"""
6060
derivs = model.rhs_eval
6161
y0 = model.y0
62-
events = model.events_eval
62+
events = model.terminate_events_eval
6363
jacobian = model.jacobian_eval
6464

6565
def eqsydot(t, y, return_ydot):

pybamm/solvers/scipy_solver.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ def _integrate(self, model, t_eval, inputs=None):
5454
extra_options.update({"jac": model.jacobian_eval})
5555

5656
# make events terminal so that the solver stops when they are reached
57-
if model.events_eval:
58-
for event in model.events_eval:
57+
if model.terminate_events_eval:
58+
for event in model.terminate_events_eval:
5959
event.terminal = True
60-
extra_options.update({"events": model.events_eval})
60+
extra_options.update({"events": model.terminate_events_eval})
6161

6262
sol = it.solve_ivp(
6363
model.rhs_eval,

pybamm/solvers/solution.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -129,22 +129,25 @@ def __add__(self, other):
129129
self.append(other)
130130
return self
131131

132-
def append(self, solution):
132+
def append(self, solution, start_index=1):
133133
"""
134+
134135
Appends solution.t and solution.y onto self.t and self.y.
135-
Note: this process removes the initial time and state of solution to avoid
136-
duplicate times and states being stored (self.t[-1] is equal to solution.t[0],
137-
and self.y[:, -1] is equal to solution.y[:, 0]).
136+
137+
Note: by default this process removes the initial time and state of solution to
138+
avoid duplicate times and states being stored (self.t[-1] is equal to
139+
solution.t[0], and self.y[:, -1] is equal to solution.y[:, 0]). Set the optional
140+
argument ``start_index`` to override this behavior
138141
139142
"""
140143
# Update t, y and inputs
141-
self.t = np.concatenate((self.t, solution.t[1:]))
142-
self.y = np.concatenate((self.y, solution.y[:, 1:]), axis=1)
144+
self.t = np.concatenate((self.t, solution.t[start_index:]))
145+
self.y = np.concatenate((self.y, solution.y[:, start_index:]), axis=1)
143146
for name, inp in self.inputs.items():
144147
solution_inp = solution.inputs[name]
145148
if isinstance(solution_inp, numbers.Number):
146149
solution_inp = solution_inp * np.ones_like(solution.t)
147-
self.inputs[name] = np.concatenate((inp, solution_inp[1:]))
150+
self.inputs[name] = np.concatenate((inp, solution_inp[start_index:]))
148151
# Update solution time
149152
self.solve_time += solution.solve_time
150153
# Update termination

tests/unit/test_solvers/test_base_solver.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def algebraic_eval(self, t, y):
5858
return y + 2
5959

6060
solver = pybamm.BaseSolver()
61-
init_cond = solver.calculate_consistent_initial_conditions(ScalarModel())
61+
model = ScalarModel()
62+
init_cond = solver.calculate_consistent_state(model)
6263
np.testing.assert_array_equal(init_cond, -2)
6364

6465
# More complicated system
@@ -75,15 +76,15 @@ def algebraic_eval(self, t, y):
7576
return (y[1:] - vec[1:]) ** 2
7677

7778
model = VectorModel()
78-
init_cond = solver.calculate_consistent_initial_conditions(model)
79+
init_cond = solver.calculate_consistent_state(model)
7980
np.testing.assert_array_almost_equal(init_cond, vec)
8081

8182
# With jacobian
8283
def jac_dense(t, y):
8384
return 2 * np.hstack([np.zeros((3, 1)), np.diag(y[1:] - vec[1:])])
8485

8586
model.jac_algebraic_eval = jac_dense
86-
init_cond = solver.calculate_consistent_initial_conditions(model)
87+
init_cond = solver.calculate_consistent_state(model)
8788
np.testing.assert_array_almost_equal(init_cond, vec)
8889

8990
# With sparse jacobian
@@ -93,7 +94,7 @@ def jac_sparse(t, y):
9394
)
9495

9596
model.jac_algebraic_eval = jac_sparse
96-
init_cond = solver.calculate_consistent_initial_conditions(model)
97+
init_cond = solver.calculate_consistent_state(model)
9798
np.testing.assert_array_almost_equal(init_cond, vec)
9899

99100
def test_fail_consistent_initial_conditions(self):
@@ -114,13 +115,13 @@ def algebraic_eval(self, t, y):
114115
pybamm.SolverError,
115116
"Could not find consistent initial conditions: The iteration is not making",
116117
):
117-
solver.calculate_consistent_initial_conditions(Model())
118+
solver.calculate_consistent_state(Model())
118119
solver = pybamm.BaseSolver()
119120
with self.assertRaisesRegex(
120121
pybamm.SolverError,
121122
"Could not find consistent initial conditions: solver terminated",
122123
):
123-
solver.calculate_consistent_initial_conditions(Model())
124+
solver.calculate_consistent_state(Model())
124125

125126

126127
if __name__ == "__main__":

0 commit comments

Comments
 (0)