Skip to content

Commit e1c51a2

Browse files
Merge pull request #1416 from pybamm-team/issue-1414-experiment-bug
#1414 fix bug in set_initial_conditions_from
2 parents 900e52e + d41d557 commit e1c51a2

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

pybamm/models/base_model.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,6 @@ def set_initial_conditions_from(self, solution, inplace=True):
384384

385385
if isinstance(solution, pybamm.Solution):
386386
solution = solution.last_state
387-
else:
388-
solution = pybamm.FuzzyDict(solution)
389387
for var, equation in model.initial_conditions.items():
390388
if isinstance(var, pybamm.Variable):
391389
try:
@@ -404,7 +402,7 @@ def set_initial_conditions_from(self, solution, inplace=True):
404402
elif final_state.ndim == 2:
405403
final_state_eval = final_state[:, -1]
406404
elif final_state.ndim == 3:
407-
final_state_eval = final_state[:, :, -1].flatten()
405+
final_state_eval = final_state[:, :, -1].flatten(order="F")
408406
else:
409407
raise NotImplementedError("Variable must be 0D, 1D, or 2D")
410408
model.initial_conditions[var] = pybamm.Vector(final_state_eval)

pybamm/solvers/casadi_algebraic_solver.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):
220220
y_diff = casadi.horzcat(*[y0_diff] * len(t_eval))
221221
y_sol = casadi.vertcat(y_diff, y_alg)
222222
# Return solution object (no events, so pass None to t_event, y_event)
223-
sol = pybamm.Solution(t_eval, y_sol, model, inputs_dict, termination="success")
223+
sol = pybamm.Solution(
224+
[t_eval], y_sol, model, inputs_dict, termination="success"
225+
)
224226
sol.integration_time = integration_time
225227
return sol

tests/unit/test_experiments/test_simulation_with_experiment.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#
22
# Test setting up a simulation with an experiment
33
#
4+
import casadi
45
import pybamm
56
import numpy as np
67
import unittest
@@ -82,12 +83,22 @@ def test_run_experiment(self):
8283
)
8384
]
8485
)
85-
model = pybamm.lithium_ion.SPM()
86+
model = pybamm.lithium_ion.DFN()
8687
sim = pybamm.Simulation(model, experiment=experiment)
8788
sol = sim.solve()
8889
self.assertEqual(sol.termination, "final time")
8990
self.assertEqual(len(sol.cycles), 1)
9091

92+
for i, step in enumerate(sol.cycles[0].steps[:-1]):
93+
len_rhs = sol.all_models[0].concatenated_rhs.size
94+
y_left = step.all_ys[-1][:len_rhs, -1]
95+
if isinstance(y_left, casadi.DM):
96+
y_left = y_left.full()
97+
y_right = sol.cycles[0].steps[i + 1].all_ys[0][:len_rhs, 0]
98+
if isinstance(y_right, casadi.DM):
99+
y_right = y_right.full()
100+
np.testing.assert_array_equal(y_left.flatten(), y_right.flatten())
101+
91102
# Solve again starting from solution
92103
sol2 = sim.solve(starting_solution=sol)
93104
self.assertEqual(sol2.termination, "final time")

0 commit comments

Comments
 (0)