Skip to content

Commit 03cc0eb

Browse files
#1100 reformat Solution syntax
1 parent 5b4f8e8 commit 03cc0eb

16 files changed

+113
-86
lines changed

pybamm/solvers/algebraic_solver.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ def _integrate(self, model, t_eval, inputs=None):
5959
inputs : dict, optional
6060
Any input parameters to pass to the model when solving
6161
"""
62-
inputs = inputs or {}
62+
inputs_dict = inputs or {}
6363
if model.convert_to_format == "casadi":
64-
inputs = casadi.vertcat(*[x for x in inputs.values()])
64+
inputs = casadi.vertcat(*[x for x in inputs_dict.values()])
6565

6666
y0 = model.y0
6767
if isinstance(y0, casadi.DM):
@@ -210,4 +210,7 @@ def jac_norm(y):
210210
y_diff = np.r_[[y0_diff] * len(t_eval)].T
211211
y_sol = np.r_[y_diff, y_alg]
212212
# Return solution object (no events, so pass None to t_event, y_event)
213-
return pybamm.Solution(t_eval, y_sol, termination="success")
213+
return pybamm.Solution(
214+
t_eval, y_sol, termination="success", model=model, inputs=inputs_dict
215+
)
216+

pybamm/solvers/base_solver.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ class BaseSolver(object):
3434
sensitivity : str, optional
3535
Whether (and how) to calculate sensitivities when solving. Options are:
3636
37-
- "explicit forward": explicitly formulate the sensitivity equations.
38-
The formulation is as per "Park, S., Kato, D., Gima, Z.,
39-
Klein, R., & Moura, S. (2018). Optimal experimental design for parameterization
40-
of an electrochemical lithium-ion battery model. Journal of The Electrochemical
41-
Society, 165(7), A1309.". See #1100 for details
37+
- "explicit forward": explicitly formulate the sensitivity equations. \
38+
The formulation is as per "Park, S., Kato, D., Gima, Z., \
39+
Klein, R., & Moura, S. (2018). Optimal experimental design for parameterization\
40+
of an electrochemical lithium-ion battery model. Journal of The Electrochemical\
41+
Society, 165(7), A1309.". See #1100 for details \
4242
- see specific solvers for other options
4343
"""
4444

@@ -891,10 +891,6 @@ def step(
891891
solution.set_up_time = set_up_time
892892
solution.solve_time = timer.time()
893893

894-
# Add model and inputs to solution
895-
solution.model = model
896-
solution.inputs = ext_and_inputs
897-
898894
# Identify the event that caused termination
899895
termination = self.get_termination_reason(solution, model.events)
900896

pybamm/solvers/casadi_algebraic_solver.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class CasadiAlgebraicSolver(pybamm.BaseSolver):
2525
Whether (and how) to calculate sensitivities when solving. Options are:
2626
2727
- None: no sensitivities
28-
- "explicit forward": explicitly formulate the sensitivity equations.
28+
- "explicit forward": explicitly formulate the sensitivity equations. \
2929
See :class:`pybamm.BaseSolver`
3030
- "casadi": use casadi to differentiate through the rootfinding operator
3131
@@ -66,7 +66,7 @@ def _integrate(self, model, t_eval, inputs=None):
6666
# Record whether there are any symbolic inputs
6767
inputs_dict = inputs or {}
6868
# Create casadi objects for the root-finder
69-
inputs = casadi.vertcat(*[v for v in inputs.values()])
69+
inputs = casadi.vertcat(*[v for v in inputs_dict.values()])
7070

7171
# Create symbolic inputs
7272
symbolic_inputs = casadi.MX.sym("inputs", inputs.shape[0])

pybamm/solvers/casadi_solver.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ class CasadiSolver(pybamm.BaseSolver):
5858
Any options to pass to the CasADi integrator when calling the integrator.
5959
Please consult `CasADi documentation <https://tinyurl.com/y5rk76os>`_ for
6060
details.
61-
sensitivity : bool, optional
61+
sensitivity : str, optional
6262
Whether (and how) to calculate sensitivities when solving. Options are:
6363
6464
- None: no sensitivities
65-
- "explicit forward": explicitly formulate the sensitivity equations.
65+
- "explicit forward": explicitly formulate the sensitivity equations. \
6666
See :class:`pybamm.BaseSolver`
6767
- "casadi": use casadi to differentiate through the integrator
6868
"""

pybamm/solvers/dummy_solver.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,6 @@ def _integrate(self, model, t_eval, inputs=None):
3333
3434
"""
3535
y_sol = np.zeros((1, t_eval.size))
36-
return pybamm.Solution(t_eval, y_sol, termination="final time")
36+
return pybamm.Solution(
37+
t_eval, y_sol, termination="final time", model=model, inputs=inputs
38+
)

pybamm/solvers/idaklu_solver.py

+3
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def _integrate(self, model, t_eval, inputs=None):
154154
t_eval : numeric type
155155
The times at which to compute the solution
156156
"""
157+
inputs_dict = inputs
157158
if model.rhs_eval.form == "casadi":
158159
# stack inputs
159160
inputs = casadi.vertcat(*[x for x in inputs.values()])
@@ -272,6 +273,8 @@ def rootfn(t, y):
272273
t[-1],
273274
np.transpose(y_out[-1])[:, np.newaxis],
274275
termination,
276+
model=model,
277+
inputs=inputs_dict,
275278
)
276279
else:
277280
raise pybamm.SolverError(sol.message)

pybamm/solvers/jax_solver.py

+32-25
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,17 @@ class JaxSolver(pybamm.BaseSolver):
4545
for details.
4646
"""
4747

48-
def __init__(self, method='RK45', root_method=None,
49-
rtol=1e-6, atol=1e-6, extra_options=None):
48+
def __init__(
49+
self, method="RK45", root_method=None, rtol=1e-6, atol=1e-6, extra_options=None
50+
):
5051
# note: bdf solver itself calculates consistent initial conditions so can set
5152
# root_method to none, allow user to override this behavior
5253
super().__init__(method, rtol, atol, root_method=root_method)
53-
method_options = ['RK45', 'BDF']
54+
method_options = ["RK45", "BDF"]
5455
if method not in method_options:
55-
raise ValueError('method must be one of {}'.format(method_options))
56+
raise ValueError("method must be one of {}".format(method_options))
5657
self.ode_solver = False
57-
if method == 'RK45':
58+
if method == "RK45":
5859
self.ode_solver = True
5960
self.extra_options = extra_options or {}
6061
self.name = "JAX solver ({})".format(method)
@@ -80,8 +81,9 @@ def get_solve(self, model, t_eval):
8081
"""
8182
if model not in self._cached_solves:
8283
if model not in self.models_set_up:
83-
raise RuntimeError("Model is not set up for solving, run"
84-
"`solver.solve(model)` first")
84+
raise RuntimeError(
85+
"Model is not set up for solving, run" "`solver.solve(model)` first"
86+
)
8587

8688
self._cached_solves[model] = self.create_solve(model, t_eval)
8789

@@ -106,32 +108,35 @@ def create_solve(self, model, t_eval):
106108
107109
"""
108110
if model.convert_to_format != "jax":
109-
raise RuntimeError("Model must be converted to JAX to use this solver"
110-
" (i.e. `model.convert_to_format = 'jax')")
111+
raise RuntimeError(
112+
"Model must be converted to JAX to use this solver"
113+
" (i.e. `model.convert_to_format = 'jax')"
114+
)
111115

112116
if model.terminate_events_eval:
113-
raise RuntimeError("Terminate events not supported for this solver."
114-
" Model has the following events:"
115-
" {}.\nYou can remove events using `model.events = []`."
116-
" It might be useful to first solve the model using a"
117-
" different solver to obtain the time of the event, then"
118-
" re-solve using no events and a fixed"
119-
" end-time".format(model.events))
117+
raise RuntimeError(
118+
"Terminate events not supported for this solver."
119+
" Model has the following events:"
120+
" {}.\nYou can remove events using `model.events = []`."
121+
" It might be useful to first solve the model using a"
122+
" different solver to obtain the time of the event, then"
123+
" re-solve using no events and a fixed"
124+
" end-time".format(model.events)
125+
)
120126

121127
# Initial conditions, make sure they are an 0D array
122128
y0 = jnp.array(model.y0).reshape(-1)
123129
mass = None
124-
if self.method == 'BDF':
130+
if self.method == "BDF":
125131
mass = model.mass_matrix.entries.toarray()
126132

127133
def rhs_ode(y, t, inputs):
128-
return model.rhs_eval(t, y, inputs),
134+
return (model.rhs_eval(t, y, inputs),)
129135

130136
def rhs_dae(y, t, inputs):
131-
return jnp.concatenate([
132-
model.rhs_eval(t, y, inputs),
133-
model.algebraic_eval(t, y, inputs),
134-
])
137+
return jnp.concatenate(
138+
[model.rhs_eval(t, y, inputs), model.algebraic_eval(t, y, inputs)]
139+
)
135140

136141
def solve_model_rk45(inputs):
137142
y = odeint(
@@ -158,7 +163,7 @@ def solve_model_bdf(inputs):
158163
)
159164
return jnp.transpose(y)
160165

161-
if self.method == 'RK45':
166+
if self.method == "RK45":
162167
return jax.jit(solve_model_rk45)
163168
else:
164169
return jax.jit(solve_model_bdf)
@@ -194,5 +199,7 @@ def _integrate(self, model, t_eval, inputs=None):
194199
termination = "final time"
195200
t_event = None
196201
y_event = onp.array(None)
197-
return pybamm.Solution(t_eval, y,
198-
t_event, y_event, termination)
202+
return pybamm.Solution(
203+
t_eval, y, t_event, y_event, termination, model=model, inputs=inputs
204+
)
205+

pybamm/solvers/scikits_dae_solver.py

+3
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def _integrate(self, model, t_eval, inputs=None):
8181
Any input parameters to pass to the model when solving
8282
8383
"""
84+
inputs_dict = inputs
8485
if model.convert_to_format == "casadi":
8586
inputs = casadi.vertcat(*[x for x in inputs.values()])
8687

@@ -150,6 +151,8 @@ def jacfn(t, y, ydot, residuals, cj, J):
150151
t_root,
151152
np.transpose(sol.roots.y),
152153
termination,
154+
model=model,
155+
inputs=inputs_dict,
153156
)
154157
else:
155158
raise pybamm.SolverError(sol.message)

pybamm/solvers/scikits_ode_solver.py

+3
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def _integrate(self, model, t_eval, inputs=None):
8080
Any input parameters to pass to the model when solving
8181
8282
"""
83+
inputs_dict = inputs
8384
if model.rhs_eval.form == "casadi":
8485
inputs = casadi.vertcat(*[x for x in inputs.values()])
8586

@@ -167,6 +168,8 @@ def jac_times_setupfn(t, y, fy, userdata):
167168
t_root,
168169
np.transpose(sol.roots.y),
169170
termination,
171+
model=model,
172+
inputs=inputs_dict,
170173
)
171174
else:
172175
raise pybamm.SolverError(sol.message)

pybamm/solvers/scipy_solver.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ class ScipySolver(pybamm.BaseSolver):
2323
Any options to pass to the solver.
2424
Please consult `SciPy documentation <https://tinyurl.com/yafgqg9y>`_ for
2525
details.
26-
sensitivity : bool, optional
27-
Whether to explicitly formulate and solve the forward sensitivity equations.
26+
sensitivity : str, optional
27+
Whether (and how) to calculate sensitivities when solving. Options are:
28+
29+
- None: no sensitivities
30+
- "explicit forward": explicitly formulate the sensitivity equations. \
2831
See :class:`pybamm.BaseSolver`
2932
"""
3033

pybamm/solvers/solution.py

+20-28
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def __init__(
5757
if isinstance(y, casadi.DM):
5858
y = y.full()
5959

60-
# if model or inputs are None, initialize empty, to be populated later
61-
self.inputs = inputs or pybamm.FuzzyDict()
62-
self._model = model or pybamm.BaseModel()
60+
# if inputs are None, initialize empty, to be populated later
61+
inputs = inputs or pybamm.FuzzyDict()
62+
self.set_inputs(inputs)
6363

6464
# If the model has been provided, split up y into solution and sensitivity
6565
# Don't do this if the sensitivity equations have not been computed (i.e. if
@@ -70,6 +70,7 @@ def __init__(
7070
model is None
7171
or isinstance(y, casadi.Function)
7272
or model.len_rhs_and_alg == y.shape[0]
73+
or model.len_rhs_and_alg == 0 # for the dummy solver
7374
):
7475
self._y = y
7576
self.sensitivity = {}
@@ -129,6 +130,8 @@ def __init__(
129130
start = end
130131
self.sensitivity = sensitivity
131132

133+
model = model or pybamm.BaseModel()
134+
self.set_model(model)
132135
self._t_event = t_event
133136
self._y_event = y_event
134137
self._termination = termination
@@ -163,39 +166,28 @@ def model(self):
163166
"Model used for solution"
164167
return self._model
165168

166-
@model.setter
167-
def model(self, value):
169+
def set_model(self, value):
168170
"Updates the model"
169-
assert isinstance(value, pybamm.BaseModel)
170171
self._model = value
171172

172173
@property
173174
def inputs(self):
174175
"Values of the inputs"
175176
return self._inputs
176177

177-
@inputs.setter
178-
def inputs(self, inputs):
178+
def set_inputs(self, inputs):
179179
"Updates the input values"
180-
# If there are symbolic inputs, just store them as given
181-
if any(isinstance(v, casadi.MX) for v in inputs.values()):
182-
self.has_symbolic_inputs = True
183-
self._inputs = inputs
184-
# Otherwise, make them the same size as the time vector
185-
else:
186-
self.has_symbolic_inputs = False
187-
self._inputs = {}
188-
for name, inp in inputs.items():
189-
# Convert number to vector of the right shape
190-
if isinstance(inp, numbers.Number):
191-
inp = inp * np.ones((1, len(self.t)))
192-
# Tile a vector
193-
else:
194-
if inp.ndim == 1:
195-
inp = np.tile(inp, (len(self.t), 1)).T
196-
else:
197-
inp = np.tile(inp, len(self.t))
198-
self._inputs[name] = inp
180+
self._inputs = {}
181+
for name, inp in inputs.items():
182+
# Convert number to vector of the right shape
183+
if isinstance(inp, numbers.Number):
184+
inp = inp * np.ones((1, len(self.t)))
185+
# Otherwise, tile a vector
186+
elif inp.ndim == 1:
187+
inp = np.tile(inp, (len(self.t), 1)).T
188+
elif inp.shape[1] != len(self.t):
189+
inp = np.tile(inp, len(self.t))
190+
self._inputs[name] = inp
199191

200192
@property
201193
def t_event(self):
@@ -434,6 +426,6 @@ def append(self, solution, start_index=1, create_sub_solutions=False):
434426
solution.termination,
435427
copy_this=solution,
436428
model=self.model,
437-
inputs=copy.copy(self.inputs),
429+
inputs=copy.copy(solution.inputs),
438430
)
439431
)

tests/unit/test_solvers/test_algebraic_solver.py

+3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class Model:
4444
timescale_eval = 1
4545
jac_algebraic_eval = None
4646
convert_to_format = "python"
47+
len_rhs_and_alg = 1
4748

4849
def algebraic_eval(self, t, y, inputs):
4950
return y + 2
@@ -66,6 +67,7 @@ class Model:
6667
timescale_eval = 1
6768
jac_algebraic_eval = None
6869
convert_to_format = "casadi"
70+
len_rhs_and_alg = 1
6971

7072
def algebraic_eval(self, t, y, inputs):
7173
# algebraic equation has no real root
@@ -95,6 +97,7 @@ class Model:
9597
rhs = {}
9698
timescale_eval = 1
9799
convert_to_format = "python"
100+
len_rhs_and_alg = 2
98101

99102
def algebraic_eval(self, t, y, inputs):
100103
return A @ y - b

tests/unit/test_solvers/test_base_solver.py

+3
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(self):
119119
)
120120
self.convert_to_format = "casadi"
121121
self.bounds = (np.array([-np.inf]), np.array([np.inf]))
122+
self.len_rhs_and_alg = 1
122123

123124
def rhs_eval(self, t, y, inputs):
124125
return np.array([])
@@ -154,6 +155,8 @@ def __init__(self):
154155
)
155156
self.convert_to_format = "casadi"
156157
self.bounds = (-np.inf * np.ones(4), np.inf * np.ones(4))
158+
self.len_rhs = 1
159+
self.len_rhs_and_alg = 4
157160

158161
def rhs_eval(self, t, y, inputs):
159162
return y[0:1]

tests/unit/test_solvers/test_scikits_solvers.py

+1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class Model:
9797
terminate_events_eval = []
9898
timescale_eval = 1
9999
convert_to_format = "python"
100+
len_rhs_and_alg = 2
100101

101102
def residuals_eval(self, t, y, ydot, inputs):
102103
return np.array(

0 commit comments

Comments
 (0)