Skip to content

Commit ac94921

Browse files
committed
#1477 took out sensitivity=casadi option, take 2
1 parent 3ebec30 commit ac94921

File tree

2 files changed

+7
-49
lines changed

2 files changed

+7
-49
lines changed

pybamm/solvers/casadi_algebraic_solver.py

+6-48
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,11 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver):
2121
Any options to pass to the CasADi rootfinder.
2222
Please consult `CasADi documentation <https://tinyurl.com/y7hrxm7d>`_ for
2323
details.
24-
sensitivity : str, optional
25-
Whether (and how) to calculate sensitivities when solving. Options are:
26-
27-
- None: no sensitivities
28-
- "explicit forward": explicitly formulate the sensitivity equations. \
29-
See :class:`pybamm.BaseSolver`
30-
- "casadi": use casadi to differentiate through the rootfinding operator
3124
3225
"""
3326

34-
def __init__(self, tol=1e-6, extra_options=None, sensitivity=None):
35-
super().__init__(sensitivity=sensitivity)
27+
def __init__(self, tol=1e-6, extra_options=None):
28+
super().__init__()
3629
self.tol = tol
3730
self.name = "CasADi algebraic solver"
3831
self.algebraic_solver = True
@@ -76,18 +69,6 @@ def _integrate(self, model, t_eval, inputs_dict=None):
7669
inputs = casadi.vertcat(*[v for v in inputs_dict.values()])
7770

7871
y0 = model.y0
79-
print('algebraic', y0)
80-
81-
# If y0 already satisfies the tolerance for all t then keep it
82-
if self.sensitivity != "casadi" and all(
83-
np.all(abs(model.casadi_algebraic(t, y0, inputs).full()) < self.tol)
84-
for t in t_eval
85-
):
86-
print('keeping soln', y0.full())
87-
pybamm.logger.debug("Keeping same solution at all times")
88-
return pybamm.Solution(
89-
t_eval, y0, model, inputs_dict, termination="success"
90-
)
9172

9273
# The casadi algebraic solver can read rhs equations, but leaves them unchanged
9374
# i.e. the part of the solution vector that corresponds to the differential
@@ -139,16 +120,6 @@ def _integrate(self, model, t_eval, inputs_dict=None):
139120
)
140121

141122
if model in self.rootfinders:
142-
if self.sensitivity == "casadi":
143-
# Reuse (symbolic) solution with new inputs
144-
y_sol = self.y_sols[model]
145-
return pybamm.Solution(
146-
t_eval,
147-
y_sol,
148-
termination="success",
149-
model=model,
150-
inputs=inputs_dict,
151-
)
152123
roots = self.rootfinders[model]
153124
else:
154125
# Set up
@@ -188,8 +159,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
188159
for idx, t in enumerate(t_eval):
189160
# Evaluate algebraic with new t and previous y0, if it's already close
190161
# enough then keep it
191-
# We can't do this if also doing sensitivity
192-
if self.sensitivity != "casadi" and np.all(
162+
if np.all(
193163
abs(model.casadi_algebraic(t, y0, inputs).full()) < self.tol
194164
):
195165
pybamm.logger.debug(
@@ -201,12 +171,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
201171
y_alg = casadi.horzcat(y_alg, y0_alg)
202172
# Otherwise calculate new y_sol
203173
else:
204-
# If doing sensitivity with casadi, evaluate with symbolic inputs
205-
# Otherwise, evaluate with actual inputs
206-
if self.sensitivity == "casadi":
207-
t_y0_diff_inputs = casadi.vertcat(t, y0_diff, symbolic_inputs)
208-
else:
209-
t_y0_diff_inputs = casadi.vertcat(t, y0_diff, inputs)
174+
t_y0_diff_inputs = casadi.vertcat(t, y0_diff, inputs)
210175
# Solve
211176
try:
212177
timer.reset()
@@ -222,11 +187,9 @@ def _integrate(self, model, t_eval, inputs_dict=None):
222187
message = err.args[0]
223188
fun = None
224189

225-
# If there are no symbolic inputs, check the function is below the tol
226-
# Skip this check if also doing sensitivity
190+
# check the function is below the tol
227191
if success and (
228-
self.sensitivity == "casadi"
229-
or (not any(np.isnan(fun)) and np.all(casadi.fabs(fun) < self.tol))
192+
not any(np.isnan(fun)) and np.all(casadi.fabs(fun) < self.tol)
230193
):
231194
# update initial guess for the next iteration
232195
y0_alg = y_alg_sol
@@ -259,11 +222,6 @@ def _integrate(self, model, t_eval, inputs_dict=None):
259222
y_diff = casadi.horzcat(*[y0_diff] * len(t_eval))
260223
y_sol = casadi.vertcat(y_diff, y_alg)
261224

262-
# If doing sensitivity, return the solution as a function of the inputs
263-
if self.sensitivity == "casadi":
264-
y_sol = casadi.Function("y_sol", [symbolic_inputs], [y_sol])
265-
# Save the solution, can just reuse and change the inputs
266-
self.y_sols[model] = y_sol
267225
# Return solution object (no events, so pass None to t_event, y_event)
268226
sol = pybamm.Solution(
269227
[t_eval], y_sol, model, inputs_dict, termination="success"

tests/unit/test_solvers/test_casadi_solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ def test_solve_sensitivity_vector_var_scalar_input(self):
660660
solver = pybamm.CasadiSolver()
661661
t_eval = np.linspace(0, 1)
662662
solution = solver.solve(model, t_eval, inputs={"param": 7},
663-
sensitivity=["param"])
663+
sensitivities=["param"])
664664
np.testing.assert_array_almost_equal(
665665
solution["var"].data, np.tile(2 * np.exp(-7 * t_eval), (n, 1)), decimal=4,
666666
)

0 commit comments

Comments
 (0)