Skip to content

Commit 6ca02be

Browse files
committed
#1477 fix some remaining bugs with algebraic solver bounds
1 parent f8bc091 commit 6ca02be

File tree

4 files changed

+29
-9
lines changed

4 files changed

+29
-9
lines changed

pybamm/solvers/base_solver.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -613,10 +613,11 @@ def jacp(*args, **kwargs):
613613
n_inputs = model.len_rhs_sens // model.len_rhs
614614
elif model.len_alg != 0:
615615
n_inputs = model.len_alg_sens // model.len_alg
616-
model.bounds = (
617-
np.repeat(model.bounds[0], n_inputs + 1),
618-
np.repeat(model.bounds[1], n_inputs + 1),
619-
)
616+
if model.bounds[0].shape[0] < model.len_alg + model.len_alg_sens:
617+
model.bounds = (
618+
np.repeat(model.bounds[0], n_inputs + 1),
619+
np.repeat(model.bounds[1], n_inputs + 1),
620+
)
620621
if (model.mass_matrix is not None
621622
and model.mass_matrix.shape[0] == model.len_rhs_and_alg):
622623

@@ -634,6 +635,11 @@ def jacp(*args, **kwargs):
634635
)
635636
else:
636637
# take care if calculate_sensitivites used then not used
638+
if model.bounds[0].shape[0] > model.len_alg:
639+
model.bounds = (
640+
model.bounds[0][:model.len_alg],
641+
model.bounds[1][:model.len_alg],
642+
)
637643
if (model.mass_matrix is not None and
638644
model.mass_matrix.shape[0] > model.len_rhs_and_alg):
639645
if model.mass_matrix_inv is not None:

pybamm/solvers/casadi_algebraic_solver.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,11 @@ def _integrate(self, model, t_eval, inputs_dict=None):
230230

231231
# Return solution object (no events, so pass None to t_event, y_event)
232232

233-
explicit_sensitivities = bool(model.calculate_sensitivities)
233+
try:
234+
explicit_sensitivities = bool(model.calculate_sensitivities)
235+
except AttributeError:
236+
explicit_sensitivities = False
237+
234238
sol = pybamm.Solution(
235239
[t_eval], y_sol, model, inputs_dict, termination="success",
236240
sensitivities=explicit_sensitivities

pybamm/solvers/casadi_solver.py

+4
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,10 @@ def _integrate(self, model, t_eval, inputs_dict=None):
272272
# update y0
273273
y0 = solution.all_ys[-1][:, -1]
274274

275+
# now we extract sensitivities from the solution
276+
if (bool(model.calculate_sensitivities)):
277+
solution.sensitivities = True
278+
275279
return solution
276280

277281
def _solve_for_event(self, coarse_solution, init_event_signs):

pybamm/solvers/solution.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,8 @@ def __init__(
8181
else:
8282
self.all_inputs = all_inputs
8383

84-
# sensitivities must be a dict or bool
85-
if not isinstance(sensitivities, (bool, dict)):
86-
raise TypeError('sensitivities arg needs to be a bool or dict')
87-
self._sensitivities = sensitivities
84+
85+
self.sensitivities = sensitivities
8886

8987
self._t_event = t_event
9088
self._y_event = y_event
@@ -285,6 +283,14 @@ def sensitivities(self):
285283
self._sensitivities = {}
286284
return self._sensitivities
287285

286+
@sensitivities.setter
287+
def sensitivities(self, value):
288+
"""Updates the sensitivity"""
289+
# sensitivities must be a dict or bool
290+
if not isinstance(value, (bool, dict)):
291+
raise TypeError('sensitivities arg needs to be a bool or dict')
292+
self._sensitivities = value
293+
288294
def set_y(self):
289295
try:
290296
if isinstance(self.all_ys[0], (casadi.DM, casadi.MX)):

0 commit comments

Comments
 (0)