Skip to content

Commit 6e91335

Browse files
committed
#1477 unit tests pass
1 parent d9ff546 commit 6e91335

14 files changed

+127
-272
lines changed

pybamm/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def version(formatted=False):
209209
#
210210
from .solvers.solution import Solution
211211
from .solvers.processed_variable import ProcessedVariable
212+
from .solvers.processed_symbolic_variable import ProcessedSymbolicVariable
212213
from .solvers.base_solver import BaseSolver
213214
from .solvers.dummy_solver import DummySolver
214215
from .solvers.algebraic_solver import AlgebraicSolver

pybamm/solvers/base_solver.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from scipy.sparse import block_diag
1212
import multiprocessing as mp
1313
import warnings
14-
import numbers
1514

1615

1716
class BaseSolver(object):
@@ -228,6 +227,13 @@ def set_up(self, model, inputs=None, t_eval=None,
228227
if calculate_sensitivites and not isinstance(self, pybamm.IDAKLUSolver):
229228
calculate_sensitivites_explicit = True
230229

230+
if calculate_sensitivites_explicit and model.convert_to_format != 'casadi':
231+
raise NotImplementedError(
232+
"Sensitivities only supported for:\n"
233+
" - model.convert_to_format = 'casadi'\n"
234+
" - IDAKLUSolver (any convert_to_format)"
235+
)
236+
231237
# save sensitivity parameters so we can identify them later on
232238
# (FYI: this is used in the Solution class)
233239
model.calculate_sensitivities = calculate_sensitivites
@@ -288,8 +294,8 @@ def report(string):
288294
jacp = None
289295
if calculate_sensitivites_explicit:
290296
raise NotImplementedError(
291-
"sensitivities using convert_to_format = 'jax' "
292-
"only implemented for IDAKLUSolver"
297+
"explicit sensitivity equations not supported for "
298+
"convert_to_format='jax'"
293299
)
294300
elif calculate_sensitivites:
295301
report((
@@ -310,11 +316,12 @@ def report(string):
310316
elif model.convert_to_format != "casadi":
311317
# Process with pybamm functions, optionally converting
312318
# to python evaluator
313-
if calculate_sensitivites:
319+
if calculate_sensitivites_explicit:
314320
raise NotImplementedError(
315-
"sensitivities only implemented with "
316-
"convert_to_format = 'casadi' or convert_to_format = 'jax'"
321+
"explicit sensitivity equations not supported for "
322+
"convert_to_format='{}'".format(model.convert_to_format)
317323
)
324+
elif calculate_sensitivites:
318325
report((
319326
f"Calculating sensitivities for {name} with respect "
320327
f"to parameters {calculate_sensitivites}"
@@ -367,7 +374,9 @@ def jacp(*args, **kwargs):
367374
# for details
368375
if name == "RHS" and model.len_rhs > 0:
369376
report(
370-
"Creating explicit forward sensitivity equations for rhs using CasADi")
377+
"Creating explicit forward sensitivity equations "
378+
"for rhs using CasADi"
379+
)
371380
df_dx = casadi.jacobian(func, y_diff)
372381
df_dp = casadi.jacobian(func, pS_casadi_stacked)
373382
S_x_mat = S_x.reshape(
@@ -386,7 +395,8 @@ def jacp(*args, **kwargs):
386395
func = casadi.vertcat(func, S_rhs)
387396
if name == "algebraic" and model.len_alg > 0:
388397
report(
389-
"Creating explicit forward sensitivity equations for algebraic using CasADi"
398+
"Creating explicit forward sensitivity equations "
399+
"for algebraic using CasADi"
390400
)
391401
dg_dz = casadi.jacobian(func, y_alg)
392402
dg_dp = casadi.jacobian(func, pS_casadi_stacked)
@@ -812,6 +822,7 @@ def solve(
812822
813823
"""
814824
pybamm.logger.info("Start solving {} with {}".format(model.name, self.name))
825+
self.calculate_sensitivites = calculate_sensitivities
815826

816827
# Make sure model isn't empty
817828
if len(model.rhs) == 0 and len(model.algebraic) == 0:
@@ -1401,7 +1412,7 @@ def _set_up_ext_and_inputs(self, model, external_variables, inputs):
14011412
name = input_param.name
14021413
if name not in inputs:
14031414
# Don't allow symbolic inputs if using `sensitivity`
1404-
if self.sensitivity == "explicit forward":
1415+
if self.calculate_sensitivites:
14051416
raise pybamm.SolverError(
14061417
"Cannot have symbolic inputs if explicitly solving forward"
14071418
"sensitivity equations"

pybamm/solvers/casadi_algebraic_solver.py

+33-40
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,6 @@ def __init__(self, tol=1e-6, extra_options=None):
3232
self.extra_options = extra_options or {}
3333
pybamm.citations.register("Andersson2019")
3434

35-
self.rootfinders = {}
36-
self.y_sols = {}
37-
3835
@property
3936
def tol(self):
4037
return self._tol
@@ -102,6 +99,14 @@ def _integrate(self, model, t_eval, inputs_dict=None):
10299

103100
y_alg = None
104101

102+
# Set up
103+
t_sym = casadi.MX.sym("t")
104+
y_alg_sym = casadi.MX.sym("y_alg", y0_alg.shape[0])
105+
y_sym = casadi.vertcat(y0_diff, y_alg_sym)
106+
107+
t_and_inputs_sym = casadi.vertcat(t_sym, symbolic_inputs)
108+
alg = model.casadi_algebraic(t_sym, y_sym, inputs)
109+
105110
# Check interpolant extrapolation
106111
if model.interpolant_extrapolation_events_eval:
107112
extrap_event = [
@@ -116,7 +121,9 @@ def _integrate(self, model, t_eval, inputs_dict=None):
116121
event.event_type
117122
== pybamm.EventType.INTERPOLANT_EXTRAPOLATION
118123
and (
119-
event.expression.evaluate(0, y0.full(), inputs=inputs)
124+
event.expression.evaluate(
125+
0, y0.full(), inputs=inputs_dict
126+
)
120127
< self.extrap_tol
121128
)
122129
):
@@ -129,40 +136,26 @@ def _integrate(self, model, t_eval, inputs_dict=None):
129136
"outside these bounds.".format(extrap_event_names)
130137
)
131138

132-
if model in self.rootfinders:
133-
roots = self.rootfinders[model]
134-
else:
135-
# Set up
136-
t_sym = casadi.MX.sym("t")
137-
y0_diff_sym = casadi.MX.sym("y0_diff", y0_diff.shape[0])
138-
y_alg_sym = casadi.MX.sym("y_alg", y0_alg.shape[0])
139-
y_sym = casadi.vertcat(y0_diff_sym, y_alg_sym)
140-
141-
t_y0diff_inputs_sym = casadi.vertcat(t_sym, y0_diff_sym, symbolic_inputs)
142-
alg = model.casadi_algebraic(t_sym, y_sym, symbolic_inputs)
143-
144-
# Set constraints vector in the casadi format
145-
# Constrain the unknowns. 0 (default): no constraint on ui, 1: ui >= 0.0,
146-
# -1: ui <= 0.0, 2: ui > 0.0, -2: ui < 0.0.
147-
constraints = np.zeros_like(model.bounds[0], dtype=int)
148-
# If the lower bound is positive then the variable must always be positive
149-
constraints[model.bounds[0] >= 0] = 1
150-
# If the upper bound is negative then the variable must always be negative
151-
constraints[model.bounds[1] <= 0] = -1
152-
153-
# Set up rootfinder
154-
roots = casadi.rootfinder(
155-
"roots",
156-
"newton",
157-
dict(x=y_alg_sym, p=t_y0diff_inputs_sym, g=alg),
158-
{
159-
**self.extra_options,
160-
"abstol": self.tol,
161-
"constraints": list(constraints[len_rhs:]),
162-
},
163-
)
164-
165-
self.rootfinders[model] = roots
139+
# Set constraints vector in the casadi format
140+
# Constrain the unknowns. 0 (default): no constraint on ui, 1: ui >= 0.0,
141+
# -1: ui <= 0.0, 2: ui > 0.0, -2: ui < 0.0.
142+
constraints = np.zeros_like(model.bounds[0], dtype=int)
143+
# If the lower bound is positive then the variable must always be positive
144+
constraints[model.bounds[0] >= 0] = 1
145+
# If the upper bound is negative then the variable must always be negative
146+
constraints[model.bounds[1] <= 0] = -1
147+
148+
# Set up rootfinder
149+
roots = casadi.rootfinder(
150+
"roots",
151+
"newton",
152+
dict(x=y_alg_sym, p=t_and_inputs_sym, g=alg),
153+
{
154+
**self.extra_options,
155+
"abstol": self.tol,
156+
"constraints": list(constraints[len_rhs:]),
157+
},
158+
)
166159

167160
timer = pybamm.Timer()
168161
integration_time = 0
@@ -182,11 +175,11 @@ def _integrate(self, model, t_eval, inputs_dict=None):
182175
y_alg = casadi.horzcat(y_alg, y0_alg)
183176
# Otherwise calculate new y_sol
184177
else:
185-
t_y0_diff_inputs = casadi.vertcat(t, y0_diff, symbolic_inputs)
178+
t_eval_inputs_sym = casadi.vertcat(t, symbolic_inputs)
186179
# Solve
187180
try:
188181
timer.reset()
189-
y_alg_sol = roots(y0_alg, t_y0_diff_inputs)
182+
y_alg_sol = roots(y0_alg, t_eval_inputs_sym)
190183
integration_time += timer.time()
191184
success = True
192185
message = None

pybamm/solvers/casadi_solver.py

-3
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ def _integrate(self, model, t_eval, inputs_dict=None):
123123
Any external variables or input parameters to pass to the model when solving
124124
"""
125125

126-
127126
# are we solving explicit forward equations?
128127
explicit_sensitivities = bool(self.calculate_sensitivites)
129128

@@ -612,8 +611,6 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
612611
else:
613612
integrator = self.integrators[model]["no grid"]
614613

615-
symbolic_inputs = casadi.MX.sym("inputs", inputs.shape[0])
616-
617614
len_rhs = model.concatenated_rhs.size
618615

619616
# Check y0 to see if it includes sensitivities

pybamm/solvers/idaklu_solver.py

+1-16
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,6 @@ class IDAKLUSolver(pybamm.BaseSolver):
3838
The tolerance for the initial-condition solver (default is 1e-6).
3939
extrap_tol : float, optional
4040
The tolerance to assert whether extrapolation occurs or not (default is 0).
41-
sensitivity : str, optional
42-
Whether (and how) to calculate sensitivities when solving. Options are:
43-
- "explicit forward": explicitly formulate the sensitivity equations. \
44-
The formulation is as per "Park, S., Kato, D., Gima, Z., \
45-
Klein, R., & Moura, S. (2018). Optimal experimental design for parameterization\
46-
of an electrochemical lithium-ion battery model. Journal of The Electrochemical\
47-
Society, 165(7), A1309.". See #1100 for details \
48-
4941
"""
5042

5143
def __init__(
@@ -56,17 +48,11 @@ def __init__(
5648
root_tol=1e-6,
5749
extrap_tol=0,
5850
max_steps="deprecated",
59-
sensitivity="idas"
6051
):
6152

6253
if idaklu_spec is None:
6354
raise ImportError("KLU is not installed")
6455

65-
if sensitivity == "explicit forward":
66-
raise NotImplementedError(
67-
"Cannot use explicit forward equations with IDAKLUSolver"
68-
)
69-
7056
super().__init__(
7157
"ida",
7258
rtol,
@@ -75,7 +61,6 @@ def __init__(
7561
root_tol,
7662
extrap_tol,
7763
max_steps,
78-
sensitivity=sensitivity,
7964
)
8065
self.name = "IDA KLU solver"
8166

@@ -339,7 +324,7 @@ def sensfn(resvalS, t, y, yp, yS, ypS):
339324
name: sol.yS[i].transpose() for i, name in enumerate(sens0.keys())
340325
}
341326
else:
342-
yS_out = None
327+
yS_out = False
343328
if sol.flag in [0, 2]:
344329
# 0 = solved for all t_eval
345330
if sol.flag == 0:

pybamm/solvers/processed_variable.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,9 @@ def initialise_sensitivity_explicit_forward(self):
524524
p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()])
525525

526526
# Convert variable to casadi format for differentiating
527-
var_casadi = self.base_variables[0].to_casadi(t_casadi, y_casadi, inputs=p_casadi)
527+
var_casadi = self.base_variables[0].to_casadi(
528+
t_casadi, y_casadi, inputs=p_casadi
529+
)
528530
dvar_dy = casadi.jacobian(var_casadi, y_casadi)
529531
dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked)
530532

pybamm/solvers/scipy_solver.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,6 @@ class ScipySolver(pybamm.BaseSolver):
2525
Any options to pass to the solver.
2626
Please consult `SciPy documentation <https://tinyurl.com/yafgqg9y>`_ for
2727
details.
28-
sensitivity : str, optional
29-
Whether (and how) to calculate sensitivities when solving. Options are:
30-
31-
- None: no sensitivities
32-
- "explicit forward": explicitly formulate the sensitivity equations. \
33-
See :class:`pybamm.BaseSolver`
3428
"""
3529

3630
def __init__(
@@ -40,14 +34,12 @@ def __init__(
4034
atol=1e-6,
4135
extrap_tol=0,
4236
extra_options=None,
43-
sensitivity=None,
4437
):
4538
super().__init__(
4639
method=method,
4740
rtol=rtol,
4841
atol=atol,
4942
extrap_tol=extrap_tol,
50-
sensitivity=sensitivity,
5143
)
5244
self.ode_solver = True
5345
self.extra_options = extra_options or {}
@@ -136,7 +128,8 @@ def event_fn(t, y):
136128
t_event = None
137129
y_event = np.array(None)
138130
sol = pybamm.Solution(
139-
sol.t, sol.y, model, inputs_dict, t_event, y_event, termination
131+
sol.t, sol.y, model, inputs_dict, t_event, y_event, termination,
132+
sensitivities=bool(self.calculate_sensitivites)
140133
)
141134
sol.integration_time = integration_time
142135
return sol

pybamm/solvers/solution.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(
8686
self._sensitivities = {}
8787
# if solution consists of explicit sensitivity equations, extract them
8888
if (
89-
sensitivities == True
89+
sensitivities is True
9090
and all_models[0] is not None
9191
and not isinstance(all_ys[0], casadi.Function)
9292
and all_models[0].len_rhs_and_alg != all_ys[0].shape[0]
@@ -97,7 +97,7 @@ def __init__(
9797
self._all_ys[0], self._sensitivities = \
9898
self._extract_explicit_sensitivities(
9999
all_models[0], all_ys[0], all_ts[0], self.all_inputs[0]
100-
)
100+
)
101101
elif isinstance(sensitivities, dict):
102102
self._sensitivities = sensitivities
103103
else:

tests/unit/test_solvers/test_base_solver.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,12 @@ def __init__(self):
163163
)
164164
self.convert_to_format = "casadi"
165165
self.bounds = (-np.inf * np.ones(4), np.inf * np.ones(4))
166-
<<<<<<< HEAD
167166
self.len_rhs = 1
168167
self.len_rhs_and_alg = 4
169168
self.interpolant_extrapolation_events_eval = []
170-
>>>>>>> develop
169+
171170
def rhs_eval(self, t, y, inputs):
171+
return y[0:1]
172172

173173
def algebraic_eval(self, t, y, inputs):
174174
return (y[1:] - vec[1:]) ** 2

0 commit comments

Comments
 (0)