Skip to content

Commit 9897494

Browse files
committed
#1477 fix for casadi manual stepper
1 parent 6e91335 commit 9897494

File tree

4 files changed

+316
-12
lines changed

4 files changed

+316
-12
lines changed

pybamm/solvers/base_solver.py

-5
Original file line numberDiff line numberDiff line change
@@ -414,11 +414,6 @@ def jacp(*args, **kwargs):
414414
(-1, 1)
415415
)
416416
func = casadi.vertcat(func, S_alg)
417-
if name == "residuals":
418-
raise NotImplementedError(
419-
"explicit forward equations not implimented for residuals"
420-
)
421-
422417
if name == "initial_conditions":
423418
if model.len_rhs == 0 or model.len_alg == 0:
424419
S_0 = casadi.jacobian(func, pS_casadi_stacked).reshape(

pybamm/solvers/casadi_solver.py

+31-7
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
197197
# Initialize solution
198198
solution = pybamm.Solution(
199199
np.array([t]), y0, model, inputs_dict,
200-
sensitivities=explicit_sensitivities
200+
sensitivities=False,
201201
)
202202
solution.solve_time = 0
203203
solution.integration_time = 0
@@ -240,7 +240,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):
240240
# halve the step size and try again.
241241
try:
242242
current_step_sol = self._run_integrator(
243-
model, y0, inputs_dict, inputs, t_window, use_grid=use_grid
243+
model, y0, inputs_dict, inputs, t_window, use_grid=use_grid,
244+
extract_sensitivities_in_solution=False,
244245
)
245246
solved = True
246247
except pybamm.SolverError:
@@ -273,6 +274,20 @@ def _integrate(self, model, t_eval, inputs_dict=None):
273274
t = t_window[-1]
274275
# update y0
275276
y0 = solution.all_ys[-1][:, -1]
277+
278+
# now we extract sensitivities from the solution
279+
if (explicit_sensitivities):
280+
# save original ys[0] and replace with separated soln
281+
# TODO: This is a dodgy hack, perhaps re-init the solution object?
282+
solution._all_ys_and_sens = [solution._all_ys[0][:]]
283+
solution._all_ys[0], solution._sensitivities = \
284+
solution._extract_explicit_sensitivities(
285+
solution.all_models[0],
286+
solution.all_ys[0],
287+
solution.all_ts[0],
288+
solution.all_inputs[0],
289+
)
290+
276291
return solution
277292

278293
def _solve_for_event(self, coarse_solution, init_event_signs):
@@ -598,12 +613,20 @@ def create_integrator(self, model, inputs, t_eval=None, use_event_switch=False):
598613

599614
return integrator
600615

601-
def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True):
616+
def _run_integrator(self, model, y0, inputs_dict,
617+
inputs, t_eval, use_grid=True,
618+
extract_sensitivities_in_solution=None,
619+
):
602620
pybamm.logger.debug("Running CasADi integrator")
603621

604622
# are we solving explicit forward equations?
605623
explicit_sensitivities = bool(self.calculate_sensitivites)
606624

625+
# by default we extract sensitivities in the solution if we
626+
# are calculating the sensitivities
627+
if extract_sensitivities_in_solution is None:
628+
extract_sensitivities_in_solution = explicit_sensitivities
629+
607630
if use_grid is True:
608631
t_eval_shifted = t_eval - t_eval[0]
609632
t_eval_shifted_rounded = np.round(t_eval_shifted, decimals=12).tobytes()
@@ -614,8 +637,9 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
614637
len_rhs = model.concatenated_rhs.size
615638

616639
# Check y0 to see if it includes sensitivities
617-
if model.len_rhs_and_alg != y0.shape[0]:
618-
len_rhs = len_rhs * (inputs.shape[0] + 1)
640+
if explicit_sensitivities:
641+
num_parameters = model.len_rhs_sens // model.len_rhs
642+
len_rhs = len_rhs * (num_parameters + 1)
619643

620644
y0_diff = y0[:len_rhs]
621645
y0_alg = y0[len_rhs:]
@@ -634,7 +658,7 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
634658
y_sol = casadi.vertcat(casadi_sol["xf"], casadi_sol["zf"])
635659
sol = pybamm.Solution(
636660
t_eval, y_sol, model, inputs_dict,
637-
sensitivities=explicit_sensitivities
661+
sensitivities=extract_sensitivities_in_solution
638662
)
639663
sol.integration_time = integration_time
640664
return sol
@@ -665,7 +689,7 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
665689

666690
sol = pybamm.Solution(
667691
t_eval, y_sol, model, inputs_dict,
668-
sensitivities=explicit_sensitivities
692+
sensitivities=extract_sensitivities_in_solution
669693
)
670694
sol.integration_time = integration_time
671695
return sol

pybamm/solvers/solution.py

+1
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def __init__(
145145
# Solution now uses CasADi
146146
pybamm.citations.register("Andersson2019")
147147

148+
148149
def _extract_explicit_sensitivities(self, model, y, t_eval, inputs):
149150
"""
150151
given a model and a solution y, extracts the sensitivities

0 commit comments

Comments
 (0)