Skip to content

Commit d9ff546

Browse files
committed
#1477 fix algebraic solver
1 parent 72560c5 commit d9ff546

File tree

4 files changed

+22
-15
lines changed

4 files changed

+22
-15
lines changed

pybamm/solvers/base_solver.py

-1
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,6 @@ def jacp(*args, **kwargs):
360360
# Add sensitivity vectors to the rhs and algebraic equations
361361
jacp = None
362362
if calculate_sensitivites_explicit:
363-
print('CASADI EXPLICIT', name, model.len_rhs)
364363
# The formulation is as per Park, S., Kato, D., Gima, Z., Klein, R.,
365364
# & Moura, S. (2018). Optimal experimental design for
366365
# parameterization of an electrochemical lithium-ion battery model.

pybamm/solvers/casadi_algebraic_solver.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def _integrate(self, model, t_eval, inputs_dict=None):
6161
"""
6262
# Record whether there are any symbolic inputs
6363
inputs_dict = inputs_dict or {}
64+
has_symbolic_inputs = any(
65+
isinstance(v, casadi.MX) for v in inputs_dict.values()
66+
)
6467
symbolic_inputs = casadi.vertcat(
6568
*[v for v in inputs_dict.values() if isinstance(v, casadi.MX)]
6669
)
@@ -70,22 +73,29 @@ def _integrate(self, model, t_eval, inputs_dict=None):
7073

7174
y0 = model.y0
7275

76+
# If y0 already satisfies the tolerance for all t then keep it
77+
if has_symbolic_inputs is False and all(
78+
np.all(abs(model.casadi_algebraic(t, y0, inputs).full()) < self.tol)
79+
for t in t_eval
80+
):
81+
pybamm.logger.debug("Keeping same solution at all times")
82+
return pybamm.Solution(
83+
t_eval, y0, model, inputs_dict, termination="success"
84+
)
85+
7386
# The casadi algebraic solver can read rhs equations, but leaves them unchanged
7487
# i.e. the part of the solution vector that corresponds to the differential
7588
# equations will be equal to the initial condition provided. This allows this
7689
# solver to be used for initialising the DAE solvers
7790
if model.rhs == {}:
78-
print('no rhs')
7991
len_rhs = 0
8092
y0_diff = casadi.DM()
8193
y0_alg = y0
8294
else:
8395
# Check y0 to see if it includes sensitivities
8496
if model.len_rhs_and_alg == y0.shape[0]:
85-
print('doesnt include sens')
8697
len_rhs = model.len_rhs
8798
else:
88-
print('includes sens', inputs.shape[0])
8999
len_rhs = model.len_rhs * (inputs.shape[0] + 1)
90100
y0_diff = y0[:len_rhs]
91101
y0_alg = y0[len_rhs:]
@@ -159,7 +169,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):
159169
for idx, t in enumerate(t_eval):
160170
# Evaluate algebraic with new t and previous y0, if it's already close
161171
# enough then keep it
162-
if np.all(
172+
# We can't do this if there are symbolic inputs
173+
if has_symbolic_inputs is False and np.all(
163174
abs(model.casadi_algebraic(t, y0, inputs).full()) < self.tol
164175
):
165176
pybamm.logger.debug(
@@ -171,7 +182,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
171182
y_alg = casadi.horzcat(y_alg, y0_alg)
172183
# Otherwise calculate new y_sol
173184
else:
174-
t_y0_diff_inputs = casadi.vertcat(t, y0_diff, inputs)
185+
t_y0_diff_inputs = casadi.vertcat(t, y0_diff, symbolic_inputs)
175186
# Solve
176187
try:
177188
timer.reset()
@@ -187,9 +198,11 @@ def _integrate(self, model, t_eval, inputs_dict=None):
187198
message = err.args[0]
188199
fun = None
189200

190-
# check the function is below the tol
201+
# If there are no symbolic inputs, check the function is below the tol
202+
# Skip this check if there are symbolic inputs
191203
if success and (
192-
not any(np.isnan(fun)) and np.all(casadi.fabs(fun) < self.tol)
204+
has_symbolic_inputs is True
205+
or (not any(np.isnan(fun)) and np.all(casadi.fabs(fun) < self.tol))
193206
):
194207
# update initial guess for the next iteration
195208
y0_alg = y_alg_sol

pybamm/solvers/casadi_solver.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def integer_bisect():
442442
np.array([t_event]),
443443
y_event[:, np.newaxis],
444444
"event",
445-
sensitivities=explicit_sensitivities
445+
sensitivities=bool(self.calculate_sensitivites)
446446
)
447447
solution.integration_time = (
448448
coarse_solution.integration_time + dense_step_sol.integration_time
@@ -665,11 +665,6 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
665665
y_sol = y_diff
666666
else:
667667
y_sol = casadi.vertcat(y_diff, y_alg)
668-
# If doing sensitivity, return the solution as a function of the inputs
669-
if self.sensitivity == "casadi":
670-
y_sol = casadi.Function("y_sol", [symbolic_inputs], [y_sol])
671-
# Save the solution, can just reuse and change the inputs
672-
self.y_sols[model] = y_sol
673668

674669
sol = pybamm.Solution(
675670
t_eval, y_sol, model, inputs_dict,

tests/unit/test_solvers/test_casadi_solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ def test_solve_sensitivity_scalar_var_scalar_input(self):
592592
t_eval = np.linspace(0, 1, 80)
593593
solution = solver.solve(
594594
model, t_eval, inputs={"p": 0.1, "q": 2, "r": -1, "s": 0.5},
595-
sensitivity=True,
595+
calculate_sensitivities=True,
596596
)
597597
np.testing.assert_allclose(solution.y[0], -1 + 0.2 * solution.t)
598598
np.testing.assert_allclose(

0 commit comments

Comments
 (0)