Skip to content

Commit 0cbe0a5

Browse files
committed
#1477 fix some bugs after merge
1 parent 836e57f commit 0cbe0a5

File tree

5 files changed

+222
-108
lines changed

5 files changed

+222
-108
lines changed

pybamm/solvers/base_solver.py

+40-27
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,15 @@ class BaseSolver(object):
3737
The tolerance to assert whether extrapolation occurs or not. Default is 0.
3838
sensitivity : str, optional
3939
Whether (and how) to calculate sensitivities when solving. Options are:
40-
- None (default): user must give the names of input parameters to calculate
41-
sensitivity via the "solve" method, the individual solver is responsible for
40+
- None (default): the individual solver is responsible for
4241
calculating the sensitivity wrt these parameters, and providing the result in
4342
the solution instance returned. At the moment this is only implemented for the
4443
IDAKLU solver.\
45-
- "explicit forward": explicitly formulate the sensitivity equations for *all*
46-
the input parameters. The formulation is as per "Park, S., Kato, D., Gima, Z., \
47-
Klein, R., & Moura, S. (2018). Optimal experimental design for parameterization\
48-
of an electrochemical lithium-ion battery model. Journal of The Electrochemical\
44+
- "explicit forward": explicitly formulate the sensitivity equations for
45+
the chosen input parameters. The formulation is as per
46+
"Park, S., Kato, D., Gima, Z., Klein, R., & Moura, S. (2018).\
47+
Optimal experimental design for parameterization of an electrochemical
48+
lithium-ion battery model. Journal of The Electrochemical\
4949
Society, 165(7), A1309.". See #1100 for details. At the moment this is only
5050
implemented using convert_to_format = 'casadi'. \
5151
- see individual solvers for other options
@@ -140,6 +140,8 @@ def copy(self):
140140
new_solver.models_set_up = {}
141141
return new_solver
142142

143+
144+
143145
def set_up(self, model, inputs=None, t_eval=None,
144146
calculate_sensitivites=False):
145147
"""Unpack model, perform checks, and calculate jacobian.
@@ -236,6 +238,9 @@ def set_up(self, model, inputs=None, t_eval=None,
236238
calculate_sensitivites = [p for p in inputs.keys()]
237239
else:
238240
calculate_sensitivites = []
241+
# save sensitivity parameters so we can identify them later on
242+
# (FYI: this is used in the Solution class)
243+
model.calculate_sensitivities = calculate_sensitivites
239244

240245
# Only allow solving explicit sensitivity equations with the casadi format for now
241246
if (
@@ -269,8 +274,13 @@ def set_up(self, model, inputs=None, t_eval=None,
269274
p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()])
270275
# sensitivity vectors
271276
if self.sensitivity == "explicit forward":
272-
S_x = casadi.MX.sym("S_x", model.len_rhs * p_casadi_stacked.shape[0])
273-
S_z = casadi.MX.sym("S_z", model.len_alg * p_casadi_stacked.shape[0])
277+
pS_casadi_stacked = casadi.vertcat(
278+
*[p_casadi[name] for name in calculate_sensitivites]
279+
)
280+
model.len_rhs_sens = model.len_rhs * pS_casadi_stacked.shape[0]
281+
model.len_alg_sens = model.len_alg * pS_casadi_stacked.shape[0]
282+
S_x = casadi.MX.sym("S_x", model.len_rhs_sens)
283+
S_z = casadi.MX.sym("S_z", model.len_alg_sens)
274284
y_and_S = casadi.vertcat(y_diff, S_x, y_alg, S_z)
275285
else:
276286
y_and_S = y_casadi
@@ -356,16 +366,16 @@ def jacp(*args, **kwargs):
356366
if name == "rhs" and model.len_rhs > 0:
357367
report("Creating sensitivity equations for rhs using CasADi")
358368
df_dx = casadi.jacobian(func, y_diff)
359-
df_dp = casadi.jacobian(func, p_casadi_stacked)
369+
df_dp = casadi.jacobian(func, pS_casadi_stacked)
360370
S_x_mat = S_x.reshape(
361-
(model.len_rhs, p_casadi_stacked.shape[0])
371+
(model.len_rhs, pS_casadi_stacked.shape[0])
362372
)
363373
if model.len_alg == 0:
364374
S_rhs = (df_dx @ S_x_mat + df_dp).reshape((-1, 1))
365375
else:
366376
df_dz = casadi.jacobian(func, y_alg)
367377
S_z_mat = S_z.reshape(
368-
(model.len_alg, p_casadi_stacked.shape[0])
378+
(model.len_alg, pS_casadi_stacked.shape[0])
369379
)
370380
S_rhs = (df_dx @ S_x_mat + df_dz @ S_z_mat + df_dp).reshape(
371381
(-1, 1)
@@ -376,34 +386,34 @@ def jacp(*args, **kwargs):
376386
"Creating sensitivity equations for algebraic using CasADi"
377387
)
378388
dg_dz = casadi.jacobian(func, y_alg)
379-
dg_dp = casadi.jacobian(func, p_casadi_stacked)
389+
dg_dp = casadi.jacobian(func, pS_casadi_stacked)
380390
S_z_mat = S_z.reshape(
381-
(model.len_alg, p_casadi_stacked.shape[0])
391+
(model.len_alg, pS_casadi_stacked.shape[0])
382392
)
383393
if model.len_rhs == 0:
384394
S_alg = (dg_dz @ S_z_mat + dg_dp).reshape((-1, 1))
385395
else:
386396
dg_dx = casadi.jacobian(func, y_diff)
387397
S_x_mat = S_x.reshape(
388-
(model.len_rhs, p_casadi_stacked.shape[0])
398+
(model.len_rhs, pS_casadi_stacked.shape[0])
389399
)
390400
S_alg = (dg_dx @ S_x_mat + dg_dz @ S_z_mat + dg_dp).reshape(
391401
(-1, 1)
392402
)
393403
func = casadi.vertcat(func, S_alg)
394404
elif name == "initial_conditions":
395405
if model.len_rhs == 0 or model.len_alg == 0:
396-
S_0 = casadi.jacobian(func, p_casadi_stacked).reshape(
406+
S_0 = casadi.jacobian(func, pS_casadi_stacked).reshape(
397407
(-1, 1)
398408
)
399409
func = casadi.vertcat(func, S_0)
400410
else:
401411
x0 = func[: model.len_rhs]
402412
z0 = func[model.len_rhs :]
403-
Sx_0 = casadi.jacobian(x0, p_casadi_stacked).reshape(
413+
Sx_0 = casadi.jacobian(x0, pS_casadi_stacked).reshape(
404414
(-1, 1)
405415
)
406-
Sz_0 = casadi.jacobian(z0, p_casadi_stacked).reshape(
416+
Sz_0 = casadi.jacobian(z0, pS_casadi_stacked).reshape(
407417
(-1, 1)
408418
)
409419
func = casadi.vertcat(x0, Sx_0, z0, Sz_0)
@@ -416,7 +426,7 @@ def jacp(*args, **kwargs):
416426
else:
417427
jac = None
418428

419-
if calculate_sensitivites:
429+
if calculate_sensitivites and self.sensitivity != "explicit forward":
420430
report((
421431
f"Calculating sensitivities for {name} with respect "
422432
f"to parameters {calculate_sensitivites} using CasADi"
@@ -529,12 +539,11 @@ def jacp(*args, **kwargs):
529539
init_eval = InitialConditions(initial_conditions, model)
530540

531541
if self.sensitivity == "explicit forward":
532-
init_eval.y_dummy = np.zeros(
533-
(
534-
model.len_rhs_and_alg * (np.vstack(list(inputs.values())).size + 1),
535-
1,
536-
)
542+
y0_total_size = (
543+
model.len_rhs + model.len_rhs_sens
544+
+ model.len_alg + model.len_alg_sens
537545
)
546+
init_eval.y_dummy = np.zeros((y0_total_size, 1))
538547
else:
539548
init_eval.y_dummy = np.zeros((model.len_rhs_and_alg, 1))
540549

@@ -546,6 +555,7 @@ def jacp(*args, **kwargs):
546555

547556
# Calculate initial conditions
548557
model.y0 = init_eval(inputs)
558+
print('YYYYY', model.y0)
549559

550560
casadi_terminate_events = []
551561
terminate_events_eval = []
@@ -710,6 +720,7 @@ def _set_initial_conditions(self, model, inputs, update_rhs):
710720
model.y0 = casadi.Function("y0", [symbolic_inputs], [y0])
711721
else:
712722
model.y0 = y0
723+
print('ASDF', model.y0)
713724

714725
def calculate_consistent_state(self, model, time=0, inputs=None):
715726
"""
@@ -736,13 +747,15 @@ def calculate_consistent_state(self, model, time=0, inputs=None):
736747
if self.root_method is None:
737748
return model.y0
738749
try:
739-
root_sol = self.root_method._integrate(model, [time], inputs)
750+
root_sol = self.root_method._integrate(model, np.array([time]), inputs)
740751
except pybamm.SolverError as e:
741752
raise pybamm.SolverError(
742753
"Could not find consistent states: {}".format(e.args[0])
743754
)
744755
pybamm.logger.debug("Found consistent states")
745-
y0 = root_sol.all_ys[0]
756+
757+
# use all_ys_and_sens in case we are solving the full sensitivity equations
758+
y0 = root_sol.all_ys_and_sens[0]
746759
if isinstance(y0, np.ndarray):
747760
y0 = y0.flatten()
748761
return y0
@@ -1428,7 +1441,7 @@ def __call__(self, t, y, inputs):
14281441
self.name, self.model.name, t * self.timescale
14291442
)
14301443
)
1431-
if self.name in ["RHS", "algebraic", "residuals"]:
1444+
if self.name in ["RHS", "algebraic", "residuals", "event"]:
14321445

14331446
return self.function(t, y, inputs).flatten()
14341447
else:
@@ -1437,7 +1450,7 @@ def __call__(self, t, y, inputs):
14371450
def function(self, t, y, inputs):
14381451
if self.form == "casadi":
14391452
states_eval = self._function(t, y, inputs)
1440-
if self.name in ["rhs", "algebraic", "residuals", "event"]:
1453+
if self.name in ["RHS", "algebraic", "residuals", "event"]:
14411454
return states_eval.full()
14421455
else:
14431456
# keep jacobians sparse

pybamm/solvers/casadi_algebraic_solver.py

+5
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,14 @@ def _integrate(self, model, t_eval, inputs_dict=None):
7676
inputs = casadi.vertcat(*[v for v in inputs_dict.values()])
7777

7878
y0 = model.y0
79+
print('algebraic', y0)
7980

8081
# If y0 already satisfies the tolerance for all t then keep it
8182
if self.sensitivity != "casadi" and all(
8283
np.all(abs(model.casadi_algebraic(t, y0, inputs).full()) < self.tol)
8384
for t in t_eval
8485
):
86+
print('keeping soln', y0.full())
8587
pybamm.logger.debug("Keeping same solution at all times")
8688
return pybamm.Solution(
8789
t_eval, y0, model, inputs_dict, termination="success"
@@ -92,14 +94,17 @@ def _integrate(self, model, t_eval, inputs_dict=None):
9294
# equations will be equal to the initial condition provided. This allows this
9395
# solver to be used for initialising the DAE solvers
9496
if model.rhs == {}:
97+
print('no rhs')
9598
len_rhs = 0
9699
y0_diff = casadi.DM()
97100
y0_alg = y0
98101
else:
99102
# Check y0 to see if it includes sensitivities
100103
if model.len_rhs_and_alg == y0.shape[0]:
104+
print('doesnt include sens')
101105
len_rhs = model.len_rhs
102106
else:
107+
print('includes sens', inputs.shape[0])
103108
len_rhs = model.len_rhs * (inputs.shape[0] + 1)
104109
y0_diff = y0[:len_rhs]
105110
y0_alg = y0[len_rhs:]

pybamm/solvers/idaklu_solver.py

+3
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,12 @@ def _integrate(self, model, t_eval, inputs_dict=None):
188188
atol = self._atol
189189

190190
y0 = model.y0
191+
print('idaklu, y0', y0)
191192
if isinstance(y0, casadi.DM):
192193
y0 = y0.full().flatten()
193194

195+
print('idaklu, y0', y0)
196+
194197
rtol = self._rtol
195198
atol = self._check_atol_type(atol, y0.size)
196199

0 commit comments

Comments
 (0)