Skip to content

Commit 46ed98e

Browse files
#804 fix tests
1 parent c029abc commit 46ed98e

File tree

4 files changed

+111
-47
lines changed

4 files changed

+111
-47
lines changed

pybamm/solvers/base_solver.py

+51-27
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ class BaseSolver(object):
2323
atol : float, optional
2424
The absolute tolerance for the solver (default is 1e-6).
2525
root_method : str, optional
26-
The method to use to find initial conditions (default is "lm")
26+
The method to use to find initial conditions (default is "casadi"). If "casadi",
27+
the solver uses casadi's Newton rootfinding algorithm to find initial
28+
conditions. Otherwise, the solver uses 'scipy.optimize.root' with method
29+
specified by 'root_method' (e.g. "lm", "hybr", ...)
2730
root_tol : float, optional
2831
The tolerance for the initial-condition solver (default is 1e-6).
2932
max_steps: int, optional
@@ -125,8 +128,11 @@ def set_up(self, model, inputs=None):
125128
"Cannot use ODE solver '{}' to solve DAE model".format(self.name)
126129
)
127130

131+
if self.ode_solver is True:
132+
self.root_method = None
128133
if (
129-
isinstance(self, pybamm.CasadiSolver) or self.root_method == "casadi"
134+
isinstance(self, pybamm.CasadiSolver)
135+
or self.root_method == "casadi"
130136
) and model.convert_to_format != "casadi":
131137
pybamm.logger.warning(
132138
f"Converting {model.name} to CasADi for solving with CasADi solver"
@@ -349,17 +355,32 @@ def calculate_consistent_state(self, model, time=0, y0_guess=None, inputs=None):
349355
# Solve using casadi or scipy
350356
if self.root_method == "casadi":
351357
# Set up
352-
print("yp")
353358
u_stacked = casadi.vertcat(*[x for x in inputs.values()])
354359
u = casadi.MX.sym("u", u_stacked.shape[0])
355-
alg = model.casadi_algebraic
356360
y_alg = casadi.MX.sym("y_alg", y0_alg_guess.shape[0])
357361
y = casadi.vertcat(y0_diff, y_alg)
358-
alg_root = alg(time, y, u)
362+
alg_root = model.casadi_algebraic(time, y, u)
359363
# Solve
360-
roots = casadi.rootfinder("roots", "newton", dict(x=y_alg, p=u, g=alg_root))
361-
y0_alg = roots(y0_alg_guess, u_stacked).full().flatten()
362-
return np.concatenate([y0_diff, y0_alg])
364+
try:
365+
# set error_on_fail to False and just check the final output is small
366+
# enough
367+
roots = casadi.rootfinder(
368+
"roots",
369+
"newton",
370+
dict(x=y_alg, p=u, g=alg_root),
371+
{"error_on_fail": False},
372+
)
373+
y0_alg = roots(y0_alg_guess, u_stacked).full().flatten()
374+
success = True
375+
message = None
376+
# Check final output
377+
fun = model.casadi_algebraic(
378+
time, casadi.vertcat(y0_diff, y0_alg), u_stacked
379+
)
380+
except RuntimeError as err:
381+
success = False
382+
message = err.args[0]
383+
fun = None
363384
else:
364385
algebraic = model.algebraic_eval
365386
jac = model.jac_algebraic_eval
@@ -404,27 +425,30 @@ def jac_fn(y0_alg):
404425
method=self.root_method,
405426
tol=self.root_tol,
406427
)
428+
# Set outputs
429+
y0_alg = sol.x
430+
success = sol.success
431+
fun = sol.fun
432+
message = sol.message
433+
434+
if success and np.all(fun < self.root_tol * len(y0_alg)):
407435
# Return full set of consistent initial conditions (y0_diff unchanged)
408-
y0_consistent = np.concatenate([y0_diff, sol.x])
409-
410-
if sol.success and np.all(sol.fun < self.root_tol * len(sol.x)):
411-
pybamm.logger.info("Finish calculating consistent initial conditions")
412-
return y0_consistent
413-
elif not sol.success:
414-
raise pybamm.SolverError(
415-
"Could not find consistent initial conditions: {}".format(
416-
sol.message
417-
)
418-
)
419-
else:
420-
raise pybamm.SolverError(
421-
"""
422-
Could not find consistent initial conditions: solver terminated
423-
successfully, but maximum solution error ({}) above tolerance ({})
424-
""".format(
425-
np.max(sol.fun), self.root_tol * len(sol.x)
426-
)
436+
y0_consistent = np.concatenate([y0_diff, y0_alg])
437+
pybamm.logger.info("Finish calculating consistent initial conditions")
438+
return y0_consistent
439+
elif not success:
440+
raise pybamm.SolverError(
441+
"Could not find consistent initial conditions: {}".format(message)
442+
)
443+
else:
444+
raise pybamm.SolverError(
445+
"""
446+
Could not find consistent initial conditions: solver terminated
447+
successfully, but maximum solution error ({}) above tolerance ({})
448+
""".format(
449+
np.max(fun), self.root_tol * len(y0_alg)
427450
)
451+
)
428452

429453
def solve(self, model, t_eval, external_variables=None, inputs=None):
430454
"""

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def load_version():
492492
# List of dependencies
493493
install_requires=[
494494
"numpy>=1.16",
495-
"scipy>=1.0",
495+
"scipy>=1.3",
496496
"pandas>=0.24",
497497
"anytree>=2.4.3",
498498
"autograd>=1.2",

tests/unit/test_solvers/test_base_solver.py

+48-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#
22
# Tests for the Base Solver class
33
#
4+
import casadi
45
import pybamm
56
import numpy as np
67
from scipy.sparse import csr_matrix
@@ -49,28 +50,47 @@ def test_ode_solver_fail_with_dae(self):
4950
def test_find_consistent_initial_conditions(self):
5051
# Simple system: a single algebraic equation
5152
class ScalarModel:
52-
concatenated_initial_conditions = np.array([[2]])
53-
jac_algebraic_eval = None
54-
timescale = 1
53+
def __init__(self):
54+
self.concatenated_initial_conditions = np.array([[2]])
55+
self.jac_algebraic_eval = None
56+
self.timescale = 1
57+
t = casadi.MX.sym("t")
58+
y = casadi.MX.sym("y")
59+
u = casadi.MX.sym("u")
60+
self.casadi_algebraic = casadi.Function(
61+
"alg", [t, y, u], [self.algebraic_eval(t, y)]
62+
)
5563

5664
def rhs_eval(self, t, y):
5765
return np.array([])
5866

5967
def algebraic_eval(self, t, y):
6068
return y + 2
6169

62-
solver = pybamm.BaseSolver()
70+
solver = pybamm.BaseSolver(root_method="lm")
6371
model = ScalarModel()
6472
init_cond = solver.calculate_consistent_state(model)
6573
np.testing.assert_array_equal(init_cond, -2)
74+
# with casadi
75+
solver_with_casadi = pybamm.BaseSolver(root_method="casadi")
76+
model = ScalarModel()
77+
init_cond = solver_with_casadi.calculate_consistent_state(model)
78+
np.testing.assert_array_equal(init_cond, -2)
6679

6780
# More complicated system
6881
vec = np.array([0.0, 1.0, 1.5, 2.0])
6982

7083
class VectorModel:
71-
concatenated_initial_conditions = np.zeros_like(vec)
72-
jac_algebraic_eval = None
73-
timescale = 1
84+
def __init__(self):
85+
self.concatenated_initial_conditions = np.zeros_like(vec)
86+
self.jac_algebraic_eval = None
87+
self.timescale = 1
88+
t = casadi.MX.sym("t")
89+
y = casadi.MX.sym("y", vec.size)
90+
u = casadi.MX.sym("u")
91+
self.casadi_algebraic = casadi.Function(
92+
"alg", [t, y, u], [self.algebraic_eval(t, y)]
93+
)
7494

7595
def rhs_eval(self, t, y):
7696
return y[0:1]
@@ -81,6 +101,9 @@ def algebraic_eval(self, t, y):
81101
model = VectorModel()
82102
init_cond = solver.calculate_consistent_state(model)
83103
np.testing.assert_array_almost_equal(init_cond, vec)
104+
# with casadi
105+
init_cond = solver_with_casadi.calculate_consistent_state(model)
106+
np.testing.assert_array_almost_equal(init_cond, vec)
84107

85108
# With jacobian
86109
def jac_dense(t, y):
@@ -102,9 +125,16 @@ def jac_sparse(t, y):
102125

103126
def test_fail_consistent_initial_conditions(self):
104127
class Model:
105-
concatenated_initial_conditions = np.array([2])
106-
jac_algebraic_eval = None
107-
timescale = 1
128+
def __init__(self):
129+
self.concatenated_initial_conditions = np.array([2])
130+
self.jac_algebraic_eval = None
131+
self.timescale = 1
132+
t = casadi.MX.sym("t")
133+
y = casadi.MX.sym("y")
134+
u = casadi.MX.sym("u")
135+
self.casadi_algebraic = casadi.Function(
136+
"alg", [t, y, u], [self.algebraic_eval(t, y)]
137+
)
108138

109139
def rhs_eval(self, t, y):
110140
return np.array([])
@@ -120,7 +150,14 @@ def algebraic_eval(self, t, y):
120150
"Could not find consistent initial conditions: The iteration is not making",
121151
):
122152
solver.calculate_consistent_state(Model())
123-
solver = pybamm.BaseSolver()
153+
solver = pybamm.BaseSolver(root_method="lm")
154+
with self.assertRaisesRegex(
155+
pybamm.SolverError,
156+
"Could not find consistent initial conditions: solver terminated",
157+
):
158+
solver.calculate_consistent_state(Model())
159+
# with casadi
160+
solver = pybamm.BaseSolver(root_method="casadi")
124161
with self.assertRaisesRegex(
125162
pybamm.SolverError,
126163
"Could not find consistent initial conditions: solver terminated",

tests/unit/test_solvers/test_scikits_solvers.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def test_model_solver_dae_python(self):
194194
disc.process_model(model)
195195

196196
# Solve
197-
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
197+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm")
198198
t_eval = np.linspace(0, 1, 100)
199199
solution = solver.solve(model, t_eval)
200200
np.testing.assert_array_equal(solution.t, t_eval)
@@ -214,7 +214,7 @@ def test_model_solver_dae_bad_ics_python(self):
214214
disc.process_model(model)
215215

216216
# Solve
217-
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
217+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm")
218218
t_eval = np.linspace(0, 1, 100)
219219
solution = solver.solve(model, t_eval)
220220
np.testing.assert_array_equal(solution.t, t_eval)
@@ -238,7 +238,7 @@ def test_model_solver_dae_events_python(self):
238238
disc.process_model(model)
239239

240240
# Solve
241-
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
241+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm")
242242
t_eval = np.linspace(0, 5, 100)
243243
solution = solver.solve(model, t_eval)
244244
np.testing.assert_array_less(solution.y[0], 1.5)
@@ -283,7 +283,7 @@ def nonsmooth_mult(t):
283283
disc.process_model(model)
284284

285285
# Solve
286-
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
286+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm")
287287

288288
# create two time series, one without a time point on the discontinuity,
289289
# and one with
@@ -355,7 +355,7 @@ def jacobian(t, y):
355355

356356
model.jacobian = jacobian
357357
# Solve
358-
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
358+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm")
359359
t_eval = np.linspace(0, 1, 100)
360360
solution = solver.solve(model, t_eval)
361361
np.testing.assert_array_equal(solution.t, t_eval)
@@ -372,7 +372,7 @@ def test_solve_ode_model_with_dae_solver_python(self):
372372
disc.process_model(model)
373373

374374
# Solve
375-
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
375+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm")
376376
t_eval = np.linspace(0, 1, 100)
377377
solution = solver.solve(model, t_eval)
378378
np.testing.assert_array_equal(solution.t, t_eval)
@@ -421,7 +421,7 @@ def test_model_step_dae_python(self):
421421
disc = get_discretisation_for_testing()
422422
disc.process_model(model)
423423

424-
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
424+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm")
425425

426426
# Step once
427427
dt = 1
@@ -514,7 +514,10 @@ def test_model_solver_dae_inputs_events(self):
514514
disc.process_model(model)
515515

516516
# Solve
517-
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
517+
if form == "python":
518+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8, root_method="lm")
519+
else:
520+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
518521
t_eval = np.linspace(0, 5, 100)
519522
solution = solver.solve(model, t_eval, inputs={"rate 1": 0.1, "rate 2": 2})
520523
np.testing.assert_array_less(solution.y[0], 1.5)

0 commit comments

Comments
 (0)