Skip to content

Commit e70e057

Browse files
#1100 fixed some solver tests
1 parent be94c59 commit e70e057

6 files changed

+277
-203
lines changed

pybamm/solvers/casadi_algebraic_solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _integrate(self, model, t_eval, inputs=None):
7373
y0 = model.y0
7474

7575
# If y0 already satisfies the tolerance for all t then keep it
76-
if all(
76+
if self.sensitivity != "casadi" and all(
7777
np.all(abs(model.casadi_algebraic(t, y0, inputs).full()) < self.tol)
7878
for t in t_eval
7979
):

pybamm/solvers/processed_variable.py

+14-42
Original file line numberDiff line numberDiff line change
@@ -54,48 +54,21 @@ def __init__(self, base_variable, base_variable_casadi, solution, warn=True):
5454
self.base_variable = base_variable
5555
self.base_variable_casadi = base_variable_casadi
5656
self.t_sol = solution.t
57-
self.u_sol = solution.y
5857
self.mesh = base_variable.mesh
59-
self.inputs = solution.inputs
6058
self.domain = base_variable.domain
6159
self.auxiliary_domains = base_variable.auxiliary_domains
6260
self.warn = warn
6361

62+
self.inputs = solution.inputs
63+
self.symbolic_inputs = solution._symbolic_inputs
64+
65+
self.u_sol = solution.y
66+
self.y_sym = solution._y_sym
67+
6468
# Sensitivity starts off uninitialized, only set when called
6569
self._sensitivity = None
6670
self.solution_sensitivity = solution.sensitivity
6771

68-
# Special case: symbolic solution, with casadi
69-
if isinstance(solution.y, casadi.Function):
70-
# Evaluate solution at specific inputs value
71-
inputs_stacked = casadi.vertcat(*solution.inputs.values())
72-
self.u_sol = solution.y(inputs_stacked).full()
73-
# Convert variable to casadi
74-
t_MX = casadi.MX.sym("t")
75-
y_MX = casadi.MX.sym("y", self.u_sol.shape[0])
76-
# Make all inputs symbolic first for converting to casadi
77-
symbolic_inputs_dict = {
78-
name: casadi.MX.sym(name, value.shape[0])
79-
for name, value in solution.inputs.items()
80-
}
81-
82-
# The symbolic_inputs will be used for sensitivity
83-
symbolic_inputs = casadi.vertcat(*symbolic_inputs_dict.values())
84-
var_casadi = base_variable.to_casadi(
85-
t_MX, y_MX, inputs=symbolic_inputs_dict
86-
)
87-
self.base_variable_sym = casadi.Function(
88-
"variable", [t_MX, y_MX, symbolic_inputs], [var_casadi]
89-
)
90-
# Store symbolic inputs for sensitivity
91-
self.symbolic_inputs = symbolic_inputs
92-
self.y_sym = solution.y(symbolic_inputs)
93-
else:
94-
self.u_sol = solution.y
95-
self.base_variable_sym = None
96-
self.symbolic_inputs = None
97-
self.y_sym = None
98-
9972
# Set timescale
10073
self.timescale = solution.timescale_eval
10174
self.t_pts = self.t_sol * self.timescale
@@ -565,17 +538,16 @@ def sensitivity(self):
565538
return {}
566539
# Otherwise initialise and return sensitivity
567540
if self._sensitivity is None:
568-
# Check that we can compute sensitivities
569-
if self.base_variable_sym is None and self.solution_sensitivity == {}:
541+
if self.solution_sensitivity != {}:
542+
self.initialise_sensitivity_explicit_forward()
543+
elif self.y_sym is not None:
544+
self.initialise_sensitivity_casadi()
545+
else:
570546
raise ValueError(
571547
"Cannot compute sensitivities. The 'sensitivity' argument of the "
572548
"solver should be changed from 'None' to allow sensitivity "
573549
"calculations. Check solver documentation for details."
574550
)
575-
if self.base_variable_sym is None:
576-
self.initialise_sensitivity_explicit_forward()
577-
else:
578-
self.initialise_sensitivity_casadi()
579551
return self._sensitivity
580552

581553
def initialise_sensitivity_explicit_forward(self):
@@ -639,7 +611,7 @@ def initialise_0D_symbolic():
639611
for idx in range(len(self.t_sol)):
640612
t = self.t_sol[idx]
641613
u = self.y_sym[:, idx]
642-
next_entries = self.base_variable_sym(t, u, self.symbolic_inputs)
614+
next_entries = self.base_variable_casadi(t, u, self.symbolic_inputs)
643615
if idx == 0:
644616
entries = next_entries
645617
else:
@@ -653,7 +625,7 @@ def initialise_1D_symbolic():
653625
for idx in range(len(self.t_sol)):
654626
t = self.t_sol[idx]
655627
u = self.y_sym[:, idx]
656-
next_entries = self.base_variable_sym(t, u, self.symbolic_inputs)
628+
next_entries = self.base_variable_casadi(t, u, self.symbolic_inputs)
657629
if idx == 0:
658630
entries = next_entries
659631
else:
@@ -662,7 +634,7 @@ def initialise_1D_symbolic():
662634
return entries
663635

664636
inputs_stacked = casadi.vertcat(*self.inputs.values())
665-
self.base_eval = self.base_variable_sym(
637+
self.base_eval = self.base_variable_casadi(
666638
self.t_sol[0], self.u_sol[:, 0], inputs_stacked
667639
)
668640
if (

pybamm/solvers/solution.py

+32-21
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,7 @@ def __init__(
5353
inputs=None,
5454
):
5555
self.t = t
56-
if isinstance(y, casadi.DM):
57-
y = y.full()
58-
59-
# if inputs are None, initialize empty, to be populated later
60-
inputs = inputs or pybamm.FuzzyDict()
61-
self.set_inputs(inputs)
56+
self.inputs = inputs
6257

6358
# If the model has been provided, split up y into solution and sensitivity
6459
# Don't do this if the sensitivity equations have not been computed (i.e. if
@@ -107,8 +102,12 @@ def __init__(
107102
# tn_xn_p0, tn_xn_p1, ..., tn_xn_pn
108103
# 1, Extract rhs and alg sensitivities and reshape into 3D matrices
109104
# with shape (n_p, n_states, n_t)
110-
ode_sens = y[n_rhs:len_rhs_and_sens, :].reshape(n_p, n_rhs, n_t)
111-
alg_sens = y[len_rhs_and_sens + n_alg :, :].reshape(n_p, n_alg, n_t)
105+
if isinstance(y, casadi.DM):
106+
y_full = y.full()
107+
else:
108+
y_full = y
109+
ode_sens = y_full[n_rhs:len_rhs_and_sens, :].reshape(n_p, n_rhs, n_t)
110+
alg_sens = y_full[len_rhs_and_sens + n_alg :, :].reshape(n_p, n_alg, n_t)
112111
# 2. Concatenate into a single 3D matrix with shape (n_p, n_states, n_t)
113112
# i.e. along first axis
114113
full_sens_matrix = np.concatenate([ode_sens, alg_sens], axis=1)
@@ -163,8 +162,16 @@ def y(self):
163162

164163
@y.setter
165164
def y(self, y):
166-
self._y = y
167-
self._y_MX = casadi.MX.sym("y", y.shape[0])
165+
if isinstance(y, casadi.Function):
166+
self._y_fn = None
167+
inputs_stacked = casadi.vertcat(*self.inputs.values())
168+
self._y = y(inputs_stacked)
169+
self._y_sym = y(self._symbolic_inputs)
170+
else:
171+
self._y = y
172+
self._y_fn = None
173+
self._y_sym = None
174+
self._y_MX = casadi.MX.sym("y", self._y.shape[0])
168175

169176
@property
170177
def model(self):
@@ -196,8 +203,12 @@ def inputs(self):
196203
"Values of the inputs"
197204
return self._inputs
198205

199-
def set_inputs(self, inputs):
206+
@inputs.setter
207+
def inputs(self, inputs):
200208
"Updates the input values"
209+
# if inputs are None, initialize empty, to be populated later
210+
inputs = inputs or pybamm.FuzzyDict()
211+
201212
# self._inputs = {}
202213
# for name, inp in inputs.items():
203214
# # Convert number to vector of the right shape
@@ -233,13 +244,13 @@ def set_inputs(self, inputs):
233244
inp = inp[:, np.newaxis]
234245
inp = np.tile(inp, len(self.t))
235246
self._inputs[name] = inp
236-
self._all_inputs_as_MX_dict = {}
237-
for key, value in self._inputs.items():
238-
self._all_inputs_as_MX_dict[key] = casadi.MX.sym("input", value.shape[0])
247+
self._symbolic_inputs_dict = {
248+
name: casadi.MX.sym(name, value.shape[0])
249+
for name, value in self.inputs.items()
250+
}
239251

240-
self._all_inputs_as_MX = casadi.vertcat(
241-
*[p for p in self._all_inputs_as_MX_dict.values()]
242-
)
252+
# The symbolic_inputs will be used for sensitivity
253+
self._symbolic_inputs = casadi.vertcat(*self._symbolic_inputs_dict.values())
243254

244255
@property
245256
def t_event(self):
@@ -298,12 +309,12 @@ def update(self, variables):
298309
# Convert variable to casadi
299310
# Make all inputs symbolic first for converting to casadi
300311
var_sym = var_pybamm.to_casadi(
301-
self._t_MX, self._y_MX, inputs=self._all_inputs_as_MX_dict
312+
self._t_MX, self._y_MX, inputs=self._symbolic_inputs_dict
302313
)
303314

304315
var_casadi = casadi.Function(
305316
"variable",
306-
[self._t_MX, self._y_MX, self._all_inputs_as_MX],
317+
[self._t_MX, self._y_MX, self._symbolic_inputs],
307318
[var_sym],
308319
)
309320
self.model._variables_casadi[key] = var_casadi
@@ -359,8 +370,8 @@ def clear_casadi_attributes(self):
359370
"Remove casadi objects for pickling, will be computed again automatically"
360371
self._t_MX = None
361372
self._y_MX = None
362-
self._all_inputs_as_MX = None
363-
self._all_inputs_as_MX_dict = None
373+
self._symbolic_inputs = None
374+
self._symbolic_inputs_dict = None
364375

365376
def save(self, filename):
366377
"""Save the whole solution using pickle"""

tests/unit/test_solvers/test_algebraic_solver.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_wrong_solver(self):
3838

3939
def test_simple_root_find(self):
4040
# Simple system: a single algebraic equation
41-
class Model:
41+
class Model(pybamm.BaseModel):
4242
y0 = np.array([2])
4343
rhs = {}
4444
timescale_eval = 1
@@ -61,7 +61,7 @@ def algebraic_eval(self, t, y, inputs):
6161
self.assertNotEqual(solution.y, -2)
6262

6363
def test_root_find_fail(self):
64-
class Model:
64+
class Model(pybamm.BaseModel):
6565
y0 = np.array([2])
6666
rhs = {}
6767
timescale_eval = 1
@@ -92,7 +92,7 @@ def test_with_jacobian(self):
9292
A = np.array([[4, 3], [1, -1]])
9393
b = np.array([0, 7])
9494

95-
class Model:
95+
class Model(pybamm.BaseModel):
9696
y0 = np.zeros(2)
9797
rhs = {}
9898
timescale_eval = 1

0 commit comments

Comments
 (0)