Skip to content

Commit ddb34c6

Browse files
#1100 merge 1221
2 parents 4c1a7af + 0d6184b commit ddb34c6

File tree

22 files changed

+415
-240
lines changed

22 files changed

+415
-240
lines changed

examples/scripts/DFN.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
# solve model
3232
t_eval = np.linspace(0, 3600, 100)
33-
solver = pybamm.CasadiSolver(mode="fast", atol=1e-6, rtol=1e-3)
33+
solver = pybamm.CasadiSolver(mode="safe", atol=1e-6, rtol=1e-3)
3434
solution = solver.solve(model, t_eval)
3535

3636
# plot

examples/scripts/experimental_protocols/cccv.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
import pybamm
55
import matplotlib.pyplot as plt
66

7-
pybamm.set_logging_level("INFO")
7+
pybamm.set_logging_level("DEBUG")
88
experiment = pybamm.Experiment(
99
[
10-
"Discharge at C/10 for 10 hours or until 3.3 V",
10+
"Discharge at C/1 for 1 hours or until 3.3 V",
1111
"Rest for 1 hour",
1212
"Charge at 1 A until 4.1 V",
1313
"Hold at 4.1 V until 50 mA",
1414
"Rest for 1 hour",
1515
]
1616
* 3
1717
)
18-
model = pybamm.lithium_ion.DFN()
18+
model = pybamm.lithium_ion.SPM()
1919
sim = pybamm.Simulation(model, experiment=experiment, solver=pybamm.CasadiSolver())
2020
sim.solve()
2121

pybamm/models/base_model.py

+36-4
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def __init__(self, name="Unnamed model"):
109109
self.external_variables = []
110110
self._parameters = None
111111
self._input_parameters = None
112+
self._variables_casadi = {}
112113

113114
# Default behaviour is to use the jacobian and simplify
114115
self.use_jacobian = True
@@ -117,6 +118,7 @@ def __init__(self, name="Unnamed model"):
117118

118119
# Model is not initially discretised
119120
self.is_discretised = False
121+
self.y_slices = None
120122

121123
# Default timescale is 1 second
122124
self.timescale = pybamm.Scalar(1)
@@ -340,6 +342,14 @@ def new_empty_copy(self):
340342
new_model.convert_to_format = self.convert_to_format
341343
new_model.timescale = self.timescale
342344
new_model.length_scales = self.length_scales
345+
346+
# Variables from discretisation
347+
new_model.is_discretised = self.is_discretised
348+
new_model.y_slices = self.y_slices
349+
new_model.concatenated_rhs = self.concatenated_rhs
350+
new_model.concatenated_algebraic = self.concatenated_algebraic
351+
new_model.concatenated_initial_conditions = self.concatenated_initial_conditions
352+
343353
return new_model
344354

345355
def new_copy(self):
@@ -425,6 +435,31 @@ def set_initial_conditions_from(self, solution, inplace=True):
425435
"Variable must have type 'Variable' or 'Concatenation'"
426436
)
427437

438+
# Also update the concatenated initial conditions if the model is already
439+
# discretised
440+
if model.is_discretised:
441+
# Unpack slices for sorting
442+
y_slices = {var.id: slce for var, slce in model.y_slices.items()}
443+
slices = []
444+
for symbol in model.initial_conditions.keys():
445+
if isinstance(symbol, pybamm.Concatenation):
446+
# must append the slice for the whole concatenation, so that
447+
# equations get sorted correctly
448+
slices.append(
449+
slice(
450+
y_slices[symbol.children[0].id][0].start,
451+
y_slices[symbol.children[-1].id][0].stop,
452+
)
453+
)
454+
else:
455+
slices.append(y_slices[symbol.id][0])
456+
equations = list(model.initial_conditions.values())
457+
# sort equations according to slices
458+
sorted_equations = [eq for _, eq in sorted(zip(slices, equations))]
459+
model.concatenated_initial_conditions = pybamm.NumpyConcatenation(
460+
*sorted_equations
461+
)
462+
428463
return model
429464

430465
def check_and_combine_dict(self, dict1, dict2):
@@ -901,10 +936,7 @@ def default_spatial_methods(self):
901936
@property
902937
def default_solver(self):
903938
"Return default solver based on whether model is ODE model or DAE model"
904-
if len(self.algebraic) == 0:
905-
return pybamm.ScipySolver()
906-
else:
907-
return pybamm.CasadiSolver(mode="safe")
939+
return pybamm.CasadiSolver(mode="safe")
908940

909941

910942
# helper functions for finding symbols

pybamm/simulation.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -532,17 +532,16 @@ def get_variable_array(self, *variables):
532532
arrays.
533533
"""
534534

535-
variable_arrays = [
536-
self.built_model.variables[var].evaluate(
537-
self.solution.t[-1], self.solution.y[:, -1]
538-
)
539-
for var in variables
540-
]
541-
542-
if len(variable_arrays) == 1:
543-
return variable_arrays[0]
544-
else:
545-
return tuple(variable_arrays)
535+
variable_arrays = {}
536+
for var in variables:
537+
processed_var = self.solution[var].data
538+
if processed_var.ndim == 1:
539+
variable_arrays[var] = processed_var[-1]
540+
elif processed_var.ndim == 2:
541+
variable_arrays[var] = processed_var[:, -1]
542+
elif processed_var.ndim == 3:
543+
variable_arrays[var] = processed_var[:, :, -1]
544+
return variable_arrays
546545

547546
def plot(self, output_variables=None, quick_plot_vars=None, **kwargs):
548547
"""
@@ -693,6 +692,8 @@ def save(self, filename):
693692
and self._solver.integrator_specs != {}
694693
):
695694
self._solver.integrator_specs = {}
695+
if self.solution is not None:
696+
self.solution.clear_casadi_attributes()
696697
with open(filename, "wb") as f:
697698
pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
698699

pybamm/solvers/base_solver.py

-1
Original file line numberDiff line numberDiff line change
@@ -1058,7 +1058,6 @@ def __init__(self, function, name, model):
10581058
self.timescale = self.model.timescale_eval
10591059

10601060
def __call__(self, t, y, inputs):
1061-
y = y.reshape(-1, 1)
10621061
if self.name in ["rhs", "algebraic", "residuals"]:
10631062
pybamm.logger.debug(
10641063
"Evaluating {} for {} at t={}".format(

pybamm/solvers/casadi_solver.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,6 @@ def _integrate(self, model, t_eval, inputs=None):
155155
return solution
156156
elif self.mode in ["safe", "safe without grid"]:
157157
y0 = model.y0
158-
if isinstance(y0, casadi.DM):
159-
y0 = y0.full().flatten()
160158
# Step-and-check
161159
t = t_eval[0]
162160
t_f = t_eval[-1]
@@ -174,7 +172,7 @@ def _integrate(self, model, t_eval, inputs=None):
174172
self.create_integrator(model, inputs_dict)
175173
# Initialize solution
176174
solution = pybamm.Solution(
177-
np.array([t]), y0[:, np.newaxis], model=model, inputs=inputs_dict
175+
np.array([t]), y0, model=model, inputs=inputs_dict
178176
)
179177
solution.solve_time = 0
180178
solution.integration_time = 0
@@ -457,7 +455,7 @@ def _run_integrator(self, model, y0, inputs_dict, t_eval):
457455
x0=y0_diff, z0=y0_alg, p=inputs_eval, **self.extra_options_call
458456
)
459457
integration_time = timer.time()
460-
y_sol = np.concatenate([sol["xf"].full(), sol["zf"].full()])
458+
y_sol = casadi.vertcat(sol["xf"], sol["zf"])
461459
sol = pybamm.Solution(t_eval, y_sol)
462460
sol.integration_time = integration_time
463461
return sol

pybamm/solvers/dummy_solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _integrate(self, model, t_eval, inputs=None):
3434
"""
3535
y_sol = np.zeros((1, t_eval.size))
3636
sol = pybamm.Solution(
37-
t_eval, y_sol, termination="final time", model=model, inputs=inputs_dict
37+
t_eval, y_sol, termination="final time", model=model, inputs=inputs
3838
)
3939
sol.integration_time = 0
4040
return sol

pybamm/solvers/jax_solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def _integrate(self, model, t_eval, inputs=None):
203203
t_event = None
204204
y_event = onp.array(None)
205205
sol = pybamm.Solution(
206-
t_eval, y, t_event, y_event, termination, model=model, inputs=inputs_dict
206+
t_eval, y, t_event, y_event, termination, model=model, inputs=inputs
207207
)
208208
sol.integration_time = integration_time
209209
return sol

pybamm/solvers/processed_variable.py

+25-65
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,25 @@ class ProcessedVariable(object):
4040
variable. Note that this can be any kind of node in the expression tree, not
4141
just a :class:`pybamm.Variable`.
4242
When evaluated, returns an array of size (m,n)
43+
base_variable_casadi : :class:`casadi.Function`
44+
A casadi function. When evaluated, returns the same thing as
45+
`base_Variable.evaluate` (but more efficiently).
4346
solution : :class:`pybamm.Solution`
4447
The solution object to be used to create the processed variables
45-
known_evals : dict
46-
Dictionary of known evaluations, to be used to speed up finding the solution
4748
warn : bool, optional
4849
Whether to raise warnings when trying to evaluate time and length scales.
4950
Default is True.
5051
"""
5152

52-
def __init__(self, base_variable, solution, known_evals=None, warn=True):
53+
def __init__(self, base_variable, base_variable_casadi, solution, warn=True):
5354
self.base_variable = base_variable
55+
self.base_variable_casadi = base_variable_casadi
5456
self.t_sol = solution.t
5557
self.u_sol = solution.y
5658
self.mesh = base_variable.mesh
5759
self.inputs = solution.inputs
5860
self.domain = base_variable.domain
5961
self.auxiliary_domains = base_variable.auxiliary_domains
60-
self.known_evals = known_evals
6162
self.warn = warn
6263

6364
# Sensitivity starts off uninitialized, only set when called
@@ -104,19 +105,10 @@ def __init__(self, base_variable, solution, known_evals=None, warn=True):
104105
self.length_scales = solution.length_scales_eval
105106

106107
# Evaluate base variable at initial time
107-
if self.known_evals:
108-
self.base_eval, self.known_evals[solution.t[0]] = base_variable.evaluate(
109-
self.t_sol[0],
110-
self.u_sol[:, 0],
111-
inputs={name: inp[:, 0] for name, inp in solution.inputs.items()},
112-
known_evals=self.known_evals[solution.t[0]],
113-
)
114-
else:
115-
self.base_eval = base_variable.evaluate(
116-
solution.t[0],
117-
solution.y[:, 0],
118-
inputs={name: inp[:, 0] for name, inp in solution.inputs.items()},
119-
)
108+
inputs = casadi.vertcat(*[inp[:, 0] for inp in self.inputs.values()])
109+
self.base_eval = self.base_variable_casadi(
110+
solution.t[0], solution.y[:, 0], inputs
111+
).full()
120112

121113
# handle 2D (in space) finite element variables differently
122114
if (
@@ -164,13 +156,8 @@ def initialise_0D(self):
164156
for idx in range(len(self.t_sol)):
165157
t = self.t_sol[idx]
166158
u = self.u_sol[:, idx]
167-
inputs = {name: inp[:, idx] for name, inp in self.inputs.items()}
168-
if self.known_evals:
169-
entries[idx], self.known_evals[t] = self.base_variable.evaluate(
170-
t, u, inputs=inputs, known_evals=self.known_evals[t]
171-
)
172-
else:
173-
entries[idx] = self.base_variable.evaluate(t, u, inputs=inputs)
159+
inputs = casadi.vertcat(*[inp[:, idx] for inp in self.inputs.values()])
160+
entries[idx] = self.base_variable_casadi(t, u, inputs).full()[0, 0]
174161

175162
# set up interpolation
176163
if len(self.t_sol) == 1:
@@ -200,15 +187,8 @@ def initialise_1D(self, fixed_t=False):
200187
for idx in range(len(self.t_sol)):
201188
t = self.t_sol[idx]
202189
u = self.u_sol[:, idx]
203-
inputs = {name: inp[:, idx] for name, inp in self.inputs.items()}
204-
if self.known_evals:
205-
eval_and_known_evals = self.base_variable.evaluate(
206-
t, u, inputs=inputs, known_evals=self.known_evals[t]
207-
)
208-
entries[:, idx] = eval_and_known_evals[0][:, 0]
209-
self.known_evals[t] = eval_and_known_evals[1]
210-
else:
211-
entries[:, idx] = self.base_variable.evaluate(t, u, inputs=inputs)[:, 0]
190+
inputs = casadi.vertcat(*[inp[:, idx] for inp in self.inputs.values()])
191+
entries[:, idx] = self.base_variable_casadi(t, u, inputs).full()[:, 0]
212192

213193
# Get node and edge values
214194
nodes = self.mesh.nodes
@@ -310,23 +290,12 @@ def initialise_2D(self):
310290
for idx in range(len(self.t_sol)):
311291
t = self.t_sol[idx]
312292
u = self.u_sol[:, idx]
313-
inputs = {name: inp[:, idx] for name, inp in self.inputs.items()}
314-
if self.known_evals:
315-
eval_and_known_evals = self.base_variable.evaluate(
316-
t, u, inputs=inputs, known_evals=self.known_evals[t]
317-
)
318-
entries[:, :, idx] = np.reshape(
319-
eval_and_known_evals[0],
320-
[first_dim_size, second_dim_size],
321-
order="F",
322-
)
323-
self.known_evals[t] = eval_and_known_evals[1]
324-
else:
325-
entries[:, :, idx] = np.reshape(
326-
self.base_variable.evaluate(t, u, inputs=inputs),
327-
[first_dim_size, second_dim_size],
328-
order="F",
329-
)
293+
inputs = casadi.vertcat(*[inp[:, idx] for inp in self.inputs.values()])
294+
entries[:, :, idx] = np.reshape(
295+
self.base_variable_casadi(t, u, inputs).full(),
296+
[first_dim_size, second_dim_size],
297+
order="F",
298+
)
330299

331300
# add points outside first dimension domain for extrapolation to
332301
# boundaries
@@ -463,22 +432,13 @@ def initialise_2D_scikit_fem(self):
463432
for idx in range(len(self.t_sol)):
464433
t = self.t_sol[idx]
465434
u = self.u_sol[:, idx]
466-
inputs = {name: inp[:, idx] for name, inp in self.inputs.items()}
435+
inputs = casadi.vertcat(*[inp[:, idx] for inp in self.inputs.values()])
467436

468-
if self.known_evals:
469-
eval_and_known_evals = self.base_variable.evaluate(
470-
t, u, inputs=inputs, known_evals=self.known_evals[t]
471-
)
472-
entries[:, :, idx] = np.reshape(
473-
eval_and_known_evals[0], [len_y, len_z], order="F"
474-
)
475-
self.known_evals[t] = eval_and_known_evals[1]
476-
else:
477-
entries[:, :, idx] = np.reshape(
478-
self.base_variable.evaluate(t, u, inputs=inputs),
479-
[len_y, len_z],
480-
order="F",
481-
)
437+
entries[:, :, idx] = np.reshape(
438+
self.base_variable_casadi(t, u, inputs).full(),
439+
[len_y, len_z],
440+
order="F",
441+
)
482442

483443
# assign attributes for reference
484444
self.entries = entries

0 commit comments

Comments
 (0)