Skip to content

Commit c2c1199

Browse files
#775 get external test working
1 parent 589ec31 commit c2c1199

File tree

4 files changed

+79
-10
lines changed

4 files changed

+79
-10
lines changed

pybamm/solvers/base_solver.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def atol(self):
4949
def atol(self, value):
5050
self._atol = value
5151

52-
def solve(self, model, t_eval, inputs=None):
52+
def solve(self, model, t_eval, external_variables=None, inputs=None):
5353
"""
5454
Execute the solver setup and calculate the solution of the model at
5555
specified times.
@@ -61,6 +61,9 @@ def solve(self, model, t_eval, inputs=None):
6161
initial_conditions
6262
t_eval : numeric type
6363
The times at which to compute the solution
64+
external_variables : dict
65+
A dictionary of external variables and their corresponding
66+
values at the current time
6467
inputs : dict, optional
6568
Any input parameters to pass to the model when solving
6669
@@ -80,6 +83,8 @@ def solve(self, model, t_eval, inputs=None):
8083
timer = pybamm.Timer()
8184
start_time = timer.time()
8285
inputs = inputs or {}
86+
self.y_pad = np.zeros((model.y_length - model.external_start, 1))
87+
self.set_external_variables(model, external_variables)
8388
if model.convert_to_format == "casadi" or isinstance(self, pybamm.CasadiSolver):
8489
self.set_up_casadi(model, inputs)
8590
else:

pybamm/solvers/casadi_solver.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
self.extra_options = extra_options
6969
self.name = "CasADi solver ({}) with '{}' mode".format(method, mode)
7070

71-
def solve(self, model, t_eval, inputs=None):
71+
def solve(self, model, t_eval, external_variables=None, inputs=None):
7272
"""
7373
Execute the solver setup and calculate the solution of the model at
7474
specified times.
@@ -80,6 +80,9 @@ def solve(self, model, t_eval, inputs=None):
8080
initial_conditions
8181
t_eval : numeric type
8282
The times at which to compute the solution
83+
external_variables : dict
84+
A dictionary of external variables and their corresponding
85+
values at the current time
8386
inputs : dict, optional
8487
Any input parameters to pass to the model when solving
8588
@@ -93,11 +96,15 @@ def solve(self, model, t_eval, inputs=None):
9396
"""
9497
if self.mode == "fast":
9598
# Solve model normally by calling the solve method from parent class
96-
return super().solve(model, t_eval, inputs=inputs)
99+
return super().solve(
100+
model, t_eval, external_variables=external_variables, inputs=inputs
101+
)
97102
elif model.events == {}:
98103
pybamm.logger.info("No events found, running fast mode")
99104
# Solve model normally by calling the solve method from parent class
100-
return super().solve(model, t_eval, inputs=inputs)
105+
return super().solve(
106+
model, t_eval, external_variables=external_variables, inputs=inputs
107+
)
101108
elif self.mode == "safe":
102109
# Step-and-check
103110
timer = pybamm.Timer()
@@ -122,7 +129,12 @@ def solve(self, model, t_eval, inputs=None):
122129
# different to t_eval, but shouldn't matter too much as it should
123130
# only happen near events.
124131
try:
125-
current_step_sol = self.step(model, dt, inputs=inputs)
132+
current_step_sol = self.step(
133+
model,
134+
dt,
135+
external_variables=external_variables,
136+
inputs=inputs,
137+
)
126138
solved = True
127139
except pybamm.SolverError:
128140
dt /= 2
@@ -229,6 +241,11 @@ def integrate_casadi(self, rhs, algebraic, y0, t_eval, inputs=None):
229241
Any input parameters to pass to the model when solving
230242
"""
231243
inputs = inputs or {}
244+
if self.y_ext is None:
245+
y_ext = np.array([])
246+
else:
247+
y_ext = self.y_ext
248+
232249
options = {
233250
"grid": t_eval,
234251
"reltol": self.rtol,
@@ -242,16 +259,15 @@ def integrate_casadi(self, rhs, algebraic, y0, t_eval, inputs=None):
242259
# set up and solve
243260
t = casadi.MX.sym("t")
244261
u = casadi.vertcat(*[x for x in inputs.values()])
245-
y_diff = self.y_diff
262+
y0_w_ext = casadi.vertcat(y0, y_ext[len(y0) :])
263+
y_diff = casadi.MX.sym("y_diff", rhs(0, y0_w_ext, u).shape[0])
246264
problem = {"t": t, "x": y_diff}
247265
if algebraic is None:
248-
y_casadi_w_ext = casadi.vertcat(y_diff, self.y_ext[y_diff.shape[0] :])
266+
y_casadi_w_ext = casadi.vertcat(y_diff, y_ext[len(y0) :])
249267
problem.update({"ode": rhs(t, y_casadi_w_ext, u)})
250268
else:
251269
y_alg = self.y_alg
252-
y_casadi_w_ext = casadi.vertcat(
253-
y_diff, y_alg, self.y_ext[y_diff.shape[0] + y_alg.shape[0] :]
254-
)
270+
y_casadi_w_ext = casadi.vertcat(y_diff, y_alg, y_ext[len(y0) :])
255271
problem.update(
256272
{
257273
"z": y_alg,

tests/unit/test_solvers/test_casadi_solver.py

+24
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,30 @@ def test_model_solver_with_inputs(self):
214214
np.testing.assert_array_equal(solution.t, t_eval[: len(solution.t)])
215215
np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t), rtol=1e-06)
216216

217+
def test_model_solver_with_external(self):
218+
# Create model
219+
model = pybamm.BaseModel()
220+
domain = ["negative electrode", "separator", "positive electrode"]
221+
var1 = pybamm.Variable("var1", domain=domain)
222+
var2 = pybamm.Variable("var2", domain=domain)
223+
model.rhs = {var1: -var2}
224+
model.initial_conditions = {var1: 1}
225+
model.external_variables = [var2]
226+
model.variables = {"var1": var1, "var2": var2}
227+
# No need to set parameters; can use base discretisation (no spatial
228+
# operators)
229+
230+
# create discretisation
231+
mesh = get_mesh_for_testing()
232+
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
233+
disc = pybamm.Discretisation(mesh, spatial_methods)
234+
disc.process_model(model)
235+
# Solve
236+
solver = pybamm.CasadiSolver(rtol=1e-8, atol=1e-8)
237+
t_eval = np.linspace(0, 10, 100)
238+
solution = solver.solve(model, t_eval, external_variables={"var2": 0.5})
239+
np.testing.assert_allclose(solution.y[0], 1 - 0.5 * solution.t, rtol=1e-06)
240+
217241
def test_model_solver_with_non_identity_mass(self):
218242
model = pybamm.BaseModel()
219243
var1 = pybamm.Variable("var1", domain="negative electrode")

tests/unit/test_solvers/test_scikits_solvers.py

+24
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,30 @@ def test_model_solver_dae_inputs_events(self):
745745
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t))
746746
np.testing.assert_allclose(solution.y[-1], 2 * np.exp(0.1 * solution.t))
747747

748+
def test_model_solver_dae__with_external(self):
749+
# Create model
750+
model = pybamm.BaseModel()
751+
domain = ["negative electrode", "separator", "positive electrode"]
752+
var1 = pybamm.Variable("var1", domain=domain)
753+
var2 = pybamm.Variable("var2", domain=domain)
754+
model.rhs = {var1: -var2}
755+
model.initial_conditions = {var1: 1}
756+
model.external_variables = [var2]
757+
model.variables = {"var1": var1, "var2": var2}
758+
# No need to set parameters; can use base discretisation (no spatial
759+
# operators)
760+
761+
# create discretisation
762+
mesh = get_mesh_for_testing()
763+
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
764+
disc = pybamm.Discretisation(mesh, spatial_methods)
765+
disc.process_model(model)
766+
# Solve
767+
solver = pybamm.ScikitsDaeSolver(rtol=1e-8, atol=1e-8)
768+
t_eval = np.linspace(0, 10, 100)
769+
solution = solver.solve(model, t_eval, external_variables={"var2": 0.5})
770+
np.testing.assert_allclose(solution.y[0], 1 - 0.5 * solution.t, rtol=1e-06)
771+
748772
def test_solve_ode_model_with_dae_solver_casadi(self):
749773
model = pybamm.BaseModel()
750774
model.convert_to_format = "casadi"

0 commit comments

Comments
 (0)