Skip to content

Commit 08bb097

Browse files
#943 recalculate initial conditions
1 parent 5e43789 commit 08bb097

File tree

4 files changed

+165
-29
lines changed

4 files changed

+165
-29
lines changed

pybamm/solvers/base_solver.py

+69-29
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,9 @@ def set_up(self, model, inputs=None):
142142
)
143143

144144
inputs = inputs or {}
145-
y0 = model.concatenated_initial_conditions.evaluate(0, None, inputs=inputs)
145+
model.y0 = model.concatenated_initial_conditions.evaluate(
146+
0, None, inputs=inputs
147+
).flatten()
146148

147149
# Set model timescale
148150
model.timescale_eval = model.timescale.evaluate(inputs=inputs)
@@ -169,18 +171,19 @@ def set_up(self, model, inputs=None):
169171
if model.convert_to_format != "casadi":
170172
simp = pybamm.Simplification()
171173
# Create Jacobian from concatenated rhs and algebraic
172-
y = pybamm.StateVector(slice(0, np.size(y0)))
174+
y = pybamm.StateVector(slice(0, np.size(model.y0)))
173175
# set up Jacobian object, for re-use of dict
174176
jacobian = pybamm.Jacobian()
175177
else:
176178
# Convert model attributes to casadi
177179
t_casadi = casadi.MX.sym("t")
178180
y_diff = casadi.MX.sym(
179-
"y_diff", len(model.concatenated_rhs.evaluate(0, y0, inputs=inputs))
181+
"y_diff",
182+
len(model.concatenated_rhs.evaluate(0, model.y0, inputs=inputs)),
180183
)
181184
y_alg = casadi.MX.sym(
182185
"y_alg",
183-
len(model.concatenated_algebraic.evaluate(0, y0, inputs=inputs)),
186+
len(model.concatenated_algebraic.evaluate(0, model.y0, inputs=inputs)),
184187
)
185188
y_casadi = casadi.vertcat(y_diff, y_alg)
186189
p_casadi = {}
@@ -322,36 +325,69 @@ def report(string):
322325
"rhs", [t_casadi, y_casadi, p_casadi_stacked], [explicit_rhs]
323326
)
324327
model.casadi_algebraic = algebraic
325-
if self.algebraic_solver is True:
326-
# we don't calculate consistent initial conditions
327-
# for an algebraic solver as this will be the job of the algebraic solver
328+
if len(model.rhs) == 0:
329+
# No rhs equations: residuals is algebraic only
328330
model.residuals_eval = Residuals(algebraic, "residuals", model)
329331
model.jacobian_eval = jac_algebraic
330-
model.y0 = y0.flatten()
331332
elif len(model.algebraic) == 0:
332-
# can use DAE solver to solve ODE model
333-
# - no initial condition initialization needed
333+
# No algebraic equations: residuals is rhs only
334334
model.residuals_eval = Residuals(rhs, "residuals", model)
335335
model.jacobian_eval = jac_rhs
336-
model.y0 = y0.flatten()
337336
# Calculate consistent initial conditions for the algebraic equations
338337
else:
339-
if len(model.rhs) > 0:
340-
all_states = pybamm.NumpyConcatenation(
341-
model.concatenated_rhs, model.concatenated_algebraic
338+
all_states = pybamm.NumpyConcatenation(
339+
model.concatenated_rhs, model.concatenated_algebraic
340+
)
341+
# Process again, uses caching so should be quick
342+
residuals_eval, jacobian_eval = process(all_states, "residuals")[1:]
343+
model.residuals_eval = residuals_eval
344+
model.jacobian_eval = jacobian_eval
345+
346+
pybamm.logger.info("Finish solver set-up")
347+
348+
def _set_initial_conditions(self, model, inputs, update_rhs):
349+
"""
350+
Set initial conditions for the model. This is skipped if the solver is an
351+
algebraic solver (since this would make the algebraic solver redundant), and if
352+
the model doesn't have any algebraic equations (since there are no initial
353+
conditions to be calculated in this case).
354+
355+
Parameters
356+
----------
357+
model : :class:`pybamm.BaseModel`
358+
The model for which to calculate initial conditions.
359+
inputs : dict
360+
Any input parameters to pass to the model when solving
361+
update_rhs : bool
362+
Whether to update the rhs. True for 'solve', False for 'step'.
363+
364+
"""
365+
if self.algebraic_solver is True:
366+
return None
367+
elif len(model.algebraic) == 0:
368+
if update_rhs is True:
369+
# Recalculate initial conditions for the rhs equations
370+
model.y0 = model.concatenated_initial_conditions.evaluate(
371+
0, None, inputs=inputs
372+
).flatten()
373+
else:
374+
return None
375+
else:
376+
if update_rhs is True:
377+
# Recalculate initial conditions for the rhs equations
378+
y0_from_inputs = model.concatenated_initial_conditions.evaluate(
379+
0, None, inputs=inputs
380+
).flatten()
381+
# Reuse old solution for algebraic equations
382+
y0_from_model = model.y0
383+
len_rhs = len(
384+
model.concatenated_rhs.evaluate(0, model.y0, inputs=inputs)
342385
)
343-
# Process again, uses caching so should be quick
344-
residuals_eval, jacobian_eval = process(all_states, "residuals")[1:]
345-
model.residuals_eval = residuals_eval
346-
model.jacobian_eval = jacobian_eval
386+
y0_guess = np.r_[y0_from_inputs[:len_rhs], y0_from_model[len_rhs:]]
347387
else:
348-
model.residuals_eval = Residuals(algebraic, "residuals", model)
349-
model.jacobian_eval = jac_algebraic
350-
y0_guess = y0.flatten()
388+
y0_guess = model.y0
351389
model.y0 = self.calculate_consistent_state(model, 0, y0_guess, inputs)
352390

353-
pybamm.logger.info("Finish solver set-up")
354-
355391
def calculate_consistent_state(self, model, time=0, y0_guess=None, inputs=None):
356392
"""
357393
Calculate consistent state for the algebraic equations through
@@ -480,12 +516,9 @@ def jac_fn(y0_alg):
480516
)
481517
else:
482518
raise pybamm.SolverError(
483-
"""
484-
Could not find consistent initial conditions: solver terminated
485-
successfully, but maximum solution error ({}) above tolerance ({})
486-
""".format(
487-
max_fun, self.root_tol
488-
)
519+
"Could not find consistent initial conditions: solver terminated "
520+
"successfully, but maximum solution error "
521+
"({}) above tolerance ({})".format(max_fun, self.root_tol)
489522
)
490523

491524
def solve(self, model, t_eval=None, external_variables=None, inputs=None):
@@ -555,6 +588,10 @@ def solve(self, model, t_eval=None, external_variables=None, inputs=None):
555588
self.models_set_up.add(model)
556589
else:
557590
set_up_time = 0
591+
592+
# (Re-)calculate consistent initial conditions
593+
self._set_initial_conditions(model, ext_and_inputs, update_rhs=True)
594+
558595
# Non-dimensionalise time
559596
t_eval_dimensionless = t_eval / model.timescale_eval
560597
# Solve
@@ -758,6 +795,9 @@ def step(
758795
model.y0 = old_solution.y[:, -1]
759796
set_up_time = 0
760797

798+
# (Re-)calculate consistent initial conditions
799+
self._set_initial_conditions(model, ext_and_inputs, update_rhs=False)
800+
761801
# Non-dimensionalise dt
762802
dt_dimensionless = dt / model.timescale_eval
763803
# Step

tests/unit/test_solvers/test_casadi_solver.py

+36
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,42 @@ def test_model_solver_with_inputs(self):
255255
np.testing.assert_array_equal(solution.t, t_eval[: len(solution.t)])
256256
np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t), rtol=1e-06)
257257

258+
def test_model_solver_dae_inputs_in_initial_conditions(self):
259+
# Create model
260+
model = pybamm.BaseModel()
261+
var1 = pybamm.Variable("var1")
262+
var2 = pybamm.Variable("var2")
263+
model.rhs = {var1: pybamm.InputParameter("rate") * var1}
264+
model.algebraic = {var2: var1 - var2}
265+
model.initial_conditions = {
266+
var1: pybamm.InputParameter("ic 1"),
267+
var2: pybamm.InputParameter("ic 2"),
268+
}
269+
270+
# Solve
271+
solver = pybamm.CasadiSolver(rtol=1e-8, atol=1e-8)
272+
t_eval = np.linspace(0, 5, 100)
273+
solution = solver.solve(
274+
model, t_eval, inputs={"rate": -1, "ic 1": 0.1, "ic 2": 2}
275+
)
276+
np.testing.assert_array_almost_equal(
277+
solution.y[0], 0.1 * np.exp(-solution.t), decimal=5
278+
)
279+
np.testing.assert_array_almost_equal(
280+
solution.y[-1], 0.1 * np.exp(-solution.t), decimal=5
281+
)
282+
283+
# Solve again with different initial conditions
284+
solution = solver.solve(
285+
model, t_eval, inputs={"rate": -0.1, "ic 1": 1, "ic 2": 3}
286+
)
287+
np.testing.assert_array_almost_equal(
288+
solution.y[0], 1 * np.exp(-0.1 * solution.t), decimal=5
289+
)
290+
np.testing.assert_array_almost_equal(
291+
solution.y[-1], 1 * np.exp(-0.1 * solution.t), decimal=5
292+
)
293+
258294
def test_model_solver_with_external(self):
259295
# Create model
260296
model = pybamm.BaseModel()

tests/unit/test_solvers/test_scikits_solvers.py

+37
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def test_dae_integrate_bad_ics(self):
7272

7373
t_eval = np.linspace(0, 1, 100)
7474
solver.set_up(model)
75+
solver._set_initial_conditions(model, {}, True)
7576
# check y0
7677
np.testing.assert_array_equal(model.y0, [0, 0])
7778
# check dae solutions
@@ -564,6 +565,42 @@ def test_model_solver_dae_inputs_events(self):
564565
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))
565566
np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t))
566567

568+
def test_model_solver_dae_inputs_in_initial_conditions(self):
569+
# Create model
570+
model = pybamm.BaseModel()
571+
var1 = pybamm.Variable("var1")
572+
var2 = pybamm.Variable("var2")
573+
model.rhs = {var1: pybamm.InputParameter("rate") * var1}
574+
model.algebraic = {var2: var1 - var2}
575+
model.initial_conditions = {
576+
var1: pybamm.InputParameter("ic 1"),
577+
var2: pybamm.InputParameter("ic 2"),
578+
}
579+
580+
# Solve
581+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
582+
t_eval = np.linspace(0, 5, 100)
583+
solution = solver.solve(
584+
model, t_eval, inputs={"rate": -1, "ic 1": 0.1, "ic 2": 2}
585+
)
586+
np.testing.assert_array_almost_equal(
587+
solution.y[0], 0.1 * np.exp(-solution.t), decimal=5
588+
)
589+
np.testing.assert_array_almost_equal(
590+
solution.y[-1], 0.1 * np.exp(-solution.t), decimal=5
591+
)
592+
593+
# Solve again with different initial conditions
594+
solution = solver.solve(
595+
model, t_eval, inputs={"rate": -0.1, "ic 1": 1, "ic 2": 3}
596+
)
597+
np.testing.assert_array_almost_equal(
598+
solution.y[0], 1 * np.exp(-0.1 * solution.t), decimal=5
599+
)
600+
np.testing.assert_array_almost_equal(
601+
solution.y[-1], 1 * np.exp(-0.1 * solution.t), decimal=5
602+
)
603+
567604
def test_model_solver_dae_with_external(self):
568605
# Create model
569606
model = pybamm.BaseModel()

tests/unit/test_solvers/test_scipy_solver.py

+23
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,29 @@ def test_model_solver_with_inputs_with_casadi(self):
288288
np.testing.assert_array_equal(solution.t, t_eval[: len(solution.t)])
289289
np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t))
290290

291+
def test_model_solver_inputs_in_initial_conditions(self):
292+
# Create model
293+
model = pybamm.BaseModel()
294+
var1 = pybamm.Variable("var1")
295+
model.rhs = {var1: pybamm.InputParameter("rate") * var1}
296+
model.initial_conditions = {
297+
var1: pybamm.InputParameter("ic 1"),
298+
}
299+
300+
# Solve
301+
solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8)
302+
t_eval = np.linspace(0, 5, 100)
303+
solution = solver.solve(model, t_eval, inputs={"rate": -1, "ic 1": 0.1})
304+
np.testing.assert_array_almost_equal(
305+
solution.y[0], 0.1 * np.exp(-solution.t), decimal=5
306+
)
307+
308+
# Solve again with different initial conditions
309+
solution = solver.solve(model, t_eval, inputs={"rate": -0.1, "ic 1": 1})
310+
np.testing.assert_array_almost_equal(
311+
solution.y[0], 1 * np.exp(-0.1 * solution.t), decimal=5
312+
)
313+
291314

292315
if __name__ == "__main__":
293316
print("Add -v for more debug output")

0 commit comments

Comments
 (0)