Skip to content

Commit f5699c4

Browse files
committed
#1477 sorting out processed variable
1 parent ac94921 commit f5699c4

File tree

4 files changed

+94
-112
lines changed

4 files changed

+94
-112
lines changed

pybamm/solvers/base_solver.py

+9-25
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,6 @@ class BaseSolver(object):
3535
The tolerance for the initial-condition solver (default is 1e-6).
3636
extrap_tol : float, optional
3737
The tolerance to assert whether extrapolation occurs or not. Default is 0.
38-
sensitivity : str, optional
39-
Whether (and how) to calculate sensitivities when solving. Options are:
40-
- None (default): the individual solver is responsible for
41-
calculating the sensitivity wrt these parameters, and providing the result in
42-
the solution instance returned. At the moment this is only implemented for the
43-
IDAKLU solver.\
44-
- "explicit forward": explicitly formulate the sensitivity equations for
45-
the chosen input parameters. . At the moment this is only
46-
implemented using convert_to_format = 'casadi'. \
47-
- see individual solvers for other options
4838
"""
4939

5040
def __init__(
@@ -231,6 +221,8 @@ def set_up(self, model, inputs=None, t_eval=None,
231221
else:
232222
calculate_sensitivites = []
233223

224+
self.calculate_sensitivites = calculate_sensitivites
225+
234226
calculate_sensitivites_explicit = False
235227
if calculate_sensitivites and not isinstance(self, pybamm.IDAKLUSolver):
236228
calculate_sensitivites_explicit = True
@@ -360,12 +352,13 @@ def jacp(*args, **kwargs):
360352
# Add sensitivity vectors to the rhs and algebraic equations
361353
jacp = None
362354
if calculate_sensitivites_explicit:
355+
print('CASADI EXPLICIT', name, model.len_rhs)
363356
# The formulation is as per Park, S., Kato, D., Gima, Z., Klein, R.,
364357
# & Moura, S. (2018). Optimal experimental design for
365358
# parameterization of an electrochemical lithium-ion battery model.
366359
# Journal of The Electrochemical Society, 165(7), A1309.". See #1100
367360
# for details
368-
if name == "rhs" and model.len_rhs > 0:
361+
if name == "RHS" and model.len_rhs > 0:
369362
report(
370363
"Creating explicit forward sensitivity equations for rhs using CasADi")
371364
df_dx = casadi.jacobian(func, y_diff)
@@ -621,7 +614,7 @@ def jacp(*args, **kwargs):
621614

622615
# if we have changed the equations to include the explicit sensitivity
623616
# equations, then we also need to update the mass matrix
624-
if self.sensitivity == "explicit forward":
617+
if calculate_sensitivites_explicit:
625618
n_inputs = len(calculate_sensitivites)
626619
model.mass_matrix_inv = pybamm.Matrix(
627620
block_diag(
@@ -693,27 +686,21 @@ def _set_initial_conditions(self, model, inputs, update_rhs):
693686
Whether to update the rhs. True for 'solve', False for 'step'.
694687
695688
"""
696-
# Make inputs symbolic if calculating sensitivities with casadi
697-
if self.sensitivity == "casadi":
698-
symbolic_inputs = casadi.MX.sym(
699-
"inputs", casadi.vertcat(*inputs.values()).shape[0]
700-
)
701-
else:
702-
symbolic_inputs = inputs
689+
703690
if self.algebraic_solver is True:
704691
# Don't update model.y0
705692
return None
706693
elif len(model.algebraic) == 0:
707694
if update_rhs is True:
708695
# Recalculate initial conditions for the rhs equations
709-
y0 = model.init_eval(symbolic_inputs)
696+
y0 = model.init_eval(inputs)
710697
else:
711698
# Don't update model.y0
712699
return None
713700
else:
714701
if update_rhs is True:
715702
# Recalculate initial conditions for the rhs equations
716-
y0_from_inputs = model.init_eval(symbolic_inputs)
703+
y0_from_inputs = model.init_eval(inputs)
717704
# Reuse old solution for algebraic equations
718705
y0_from_model = model.y0
719706
len_rhs = model.len_rhs
@@ -726,10 +713,7 @@ def _set_initial_conditions(self, model, inputs, update_rhs):
726713
)
727714
y0 = self.calculate_consistent_state(model, 0, inputs)
728715
# Make y0 a function of inputs if doing symbolic with casadi
729-
if self.sensitivity == "casadi":
730-
model.y0 = casadi.Function("y0", [symbolic_inputs], [y0])
731-
else:
732-
model.y0 = y0
716+
model.y0 = y0
733717

734718
def calculate_consistent_state(self, model, time=0, inputs=None):
735719
"""

pybamm/solvers/casadi_solver.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
125125

126126

127127
# are we solving explicit forward equations?
128-
explicit_sensitivities = self.sensitivity == 'explicit forward'
128+
explicit_sensitivities = bool(self.calculate_sensitivites)
129129

130130
# Record whether there are any symbolic inputs
131131
inputs_dict = inputs_dict or {}
@@ -603,7 +603,7 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
603603
pybamm.logger.debug("Running CasADi integrator")
604604

605605
# are we solving explicit forward equations?
606-
explicit_sensitivities = self.sensitivity == 'explicit forward'
606+
explicit_sensitivities = bool(self.calculate_sensitivites)
607607

608608
if use_grid is True:
609609
t_eval_shifted = t_eval - t_eval[0]
@@ -613,12 +613,6 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
613613
integrator = self.integrators[model]["no grid"]
614614

615615
symbolic_inputs = casadi.MX.sym("inputs", inputs.shape[0])
616-
# If doing sensitivity with casadi, evaluate with symbolic inputs
617-
# Otherwise, evaluate with actual inputs
618-
if self.sensitivity == "casadi":
619-
inputs_eval = symbolic_inputs
620-
else:
621-
inputs_eval = inputs
622616

623617
len_rhs = model.concatenated_rhs.size
624618

@@ -656,7 +650,7 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
656650
for i in range(len(t_eval) - 1):
657651
t_min = t_eval[i]
658652
t_max = t_eval[i + 1]
659-
inputs_with_tlims = casadi.vertcat(inputs_eval, t_min, t_max)
653+
inputs_with_tlims = casadi.vertcat(inputs, t_min, t_max)
660654
timer = pybamm.Timer()
661655
casadi_sol = integrator(
662656
x0=x, z0=z, p=inputs_with_tlims, **self.extra_options_call

pybamm/solvers/processed_variable.py

+32-30
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
3939

4040
self.all_ts = solution.all_ts
4141
self.all_ys = solution.all_ys
42+
self.all_inputs = solution.all_inputs
4243
self.all_inputs_casadi = solution.all_inputs_casadi
4344

4445
self.mesh = base_variables[0].mesh
@@ -51,8 +52,8 @@ def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
5152
self.u_sol = solution.y
5253

5354
# Sensitivity starts off uninitialized, only set when called
54-
self._sensitivity = None
55-
self.all_sensitivities = solution.all_sensitivities
55+
self._sensitivities = None
56+
self.solution_sensitivities = solution.sensitivities
5657

5758
# Set timescale
5859
self.timescale = solution.timescale_eval
@@ -488,52 +489,44 @@ def data(self):
488489
"""Same as entries, but different name"""
489490
return self.entries
490491

491-
492-
class Interpolant0D:
493-
def __init__(self, entries):
494-
self.entries = entries
495-
496-
def __call__(self, t):
497-
return self.entries
498-
499492
@property
500-
def sensitivity(self):
493+
def sensitivities(self):
501494
"""
502-
Returns a dictionary of sensitivity for each input parameter.
495+
Returns a dictionary of sensitivities for each input parameter.
503496
The keys are the input parameters, and the value is a matrix of size
504497
(n_x * n_t, n_p), where n_x is the number of states, n_t is the number of time
505498
points, and n_p is the size of the input parameter
506499
"""
507-
# No sensitivity if there are no inputs
508-
if len(self.inputs) == 0:
500+
# No sensitivities if there are no inputs
501+
if len(self.all_inputs[0]) == 0:
509502
return {}
510-
# Otherwise initialise and return sensitivity
511-
if self._sensitivity is None:
512-
if self.solution_sensitivity != {}:
503+
# Otherwise initialise and return sensitivities
504+
if self._sensitivities is None:
505+
if self.solution_sensitivities != {}:
513506
self.initialise_sensitivity_explicit_forward()
514507
else:
515508
raise ValueError(
516-
"Cannot compute sensitivities. The 'sensitivity' argument of the "
517-
"solver should be changed from 'None' to allow sensitivity "
509+
"Cannot compute sensitivities. The 'sensitivities' argument of the "
510+
"solver.solve should be changed from 'None' to allow sensitivities "
518511
"calculations. Check solver documentation for details."
519512
)
520-
return self._sensitivity
513+
return self._sensitivities
521514

522515
def initialise_sensitivity_explicit_forward(self):
523516
"Set up the sensitivity dictionary"
524-
inputs_stacked = casadi.vertcat(*[p for p in self.inputs.values()])
517+
inputs_stacked = self.all_inputs_casadi[0]
525518

526519
# Set up symbolic variables
527520
t_casadi = casadi.MX.sym("t")
528521
y_casadi = casadi.MX.sym("y", self.u_sol.shape[0])
529522
p_casadi = {
530523
name: casadi.MX.sym(name, value.shape[0])
531-
for name, value in self.inputs.items()
524+
for name, value in self.all_inputs[0].items()
532525
}
533526
p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()])
534527

535528
# Convert variable to casadi format for differentiating
536-
var_casadi = self.base_variable.to_casadi(t_casadi, y_casadi, inputs=p_casadi)
529+
var_casadi = self.base_variables[0].to_casadi(t_casadi, y_casadi, inputs=p_casadi)
537530
dvar_dy = casadi.jacobian(var_casadi, y_casadi)
538531
dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked)
539532

@@ -544,8 +537,8 @@ def initialise_sensitivity_explicit_forward(self):
544537
dvar_dp_func = casadi.Function(
545538
"dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp]
546539
)
547-
for idx in range(len(self.t_sol)):
548-
t = self.t_sol[idx]
540+
for idx in range(len(self.all_ts[0])):
541+
t = self.all_ts[0][idx]
549542
u = self.u_sol[:, idx]
550543
inp = inputs_stacked[:, idx]
551544
next_dvar_dy_eval = dvar_dy_func(t, u, inp)
@@ -558,20 +551,29 @@ def initialise_sensitivity_explicit_forward(self):
558551
dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval)
559552

560553
# Compute sensitivity
561-
dy_dp = self.solution_sensitivity["all"]
554+
dy_dp = self.solution_sensitivities["all"]
562555
S_var = dvar_dy_eval @ dy_dp + dvar_dp_eval
563556

564-
sensitivity = {"all": S_var}
557+
sensitivities = {"all": S_var}
565558

566559
# Add the individual sensitivity
567560
start = 0
568-
for name, inp in self.inputs.items():
561+
for name, inp in self.all_inputs[0].items():
569562
end = start + inp.shape[0]
570-
sensitivity[name] = S_var[:, start:end]
563+
sensitivities[name] = S_var[:, start:end]
571564
start = end
572565

573566
# Save attribute
574-
self._sensitivity = sensitivity
567+
self._sensitivities = sensitivities
568+
569+
570+
class Interpolant0D:
571+
def __init__(self, entries):
572+
self.entries = entries
573+
574+
def __call__(self, t):
575+
return self.entries
576+
575577

576578
class Interpolant1D:
577579
def __init__(self, pts_for_interp, entries_for_interp):

0 commit comments

Comments
 (0)