Skip to content

Commit f97c952

Browse files
#1100 SODEs working with scipy
1 parent a59abb5 commit f97c952

File tree

5 files changed

+253
-28
lines changed

5 files changed

+253
-28
lines changed

pybamm/models/base_model.py

+13
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,19 @@ def timescale(self, value):
286286
"Set the timescale"
287287
self._timescale = value
288288

289+
@property
290+
def length_scales(self):
291+
"Length scales of model"
292+
return self._length_scale
293+
294+
@length_scales.setter
295+
def length_scales(self, values):
296+
"Set the length scale, converting any numbers to pybamm.Scalar"
297+
for domain, scale in values.items():
298+
if isinstance(scale, numbers.Number):
299+
values[domain] = pybamm.Scalar(scale)
300+
self._length_scale = values
301+
289302
@property
290303
def parameters(self):
291304
"Returns all the parameters in the model"

pybamm/solvers/base_solver.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,14 @@ def report(string):
290290
report(f"Creating sensitivity equations for rhs using CasADi")
291291
df_dx = casadi.jacobian(func, y_diff)
292292
df_dp = casadi.jacobian(func, p_casadi_stacked)
293+
S_x_mat = S_x.reshape(
294+
(model.len_rhs_and_alg, p_casadi_stacked.shape[0])
295+
)
293296
if model.len_alg == 0:
294-
S_rhs = df_dx @ S_x + df_dp
297+
S_rhs = (df_dx @ S_x_mat + df_dp).reshape((-1, 1))
295298
else:
296299
df_dz = casadi.jacobian(func, y_alg)
297-
S_rhs = df_dx @ S_x + df_dz @ S_z + df_dp
300+
S_rhs = df_dx @ S_x_mat + df_dz @ S_z + df_dp
298301
func = casadi.vertcat(func, S_rhs)
299302
elif name == "initial_conditions":
300303
if model.len_rhs == 0 or model.len_alg == 0:

pybamm/solvers/solution.py

+31-19
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,11 @@ def __init__(
6767
if model is None or model.len_rhs_and_alg == y.shape[0]:
6868
self._y = y
6969
else:
70-
all_inputs_size = np.vstack(list(inputs.values())).size
70+
n_states = model.len_rhs_and_alg
71+
n_t = len(t)
72+
n_p = np.vstack(list(inputs.values())).size
7173
# Get the point where the algebraic equations start
72-
len_rhs_and_sens = all_inputs_size * model.len_rhs
74+
len_rhs_and_sens = (n_p + 1) * model.len_rhs
7375
# self._y gets the part of the solution vector that correspond to the
7476
# actual ODE/DAE solution
7577
self._y = np.vstack(
@@ -94,28 +96,35 @@ def __init__(
9496
# tn_x1_p0, tn_x1_p1, ..., tn_x1_pn
9597
# ...
9698
# tn_xn_p0, tn_xn_p1, ..., tn_xn_pn
99+
# 1. Extract the relevant parts of y
100+
# This makes a (n_states * n_p, n_t) matrix
97101
full_sens_matrix = np.vstack(
98102
[
99103
y[model.len_rhs : len_rhs_and_sens, :],
100104
y[len_rhs_and_sens + model.len_alg :, :],
101105
]
102-
).reshape(np.prod(self._y.shape), all_inputs_size, order="F")
106+
)
107+
# 2. Transpose into a (n_t, n_states * n_p) matrix
108+
full_sens_matrix = full_sens_matrix.T
109+
# 3. Reshape into a (n_t, n_p, n_states) matrix,
110+
# then tranpose n_p and n_states to get (n_t, n_states, n_p) matrix
111+
full_sens_matrix = full_sens_matrix.reshape(n_t, n_p, n_states).transpose(
112+
0, 2, 1
113+
)
114+
# 3. Stack time and space to get a (n_t * n_states, n_p) matrix
115+
full_sens_matrix = full_sens_matrix.reshape(n_t * n_states, n_p)
116+
117+
# Save the full sensitivity matrix
118+
103119
sensitivity = {"all": full_sens_matrix}
104-
# also save the sensitivity wrt each parameter
105-
start_rhs = model.len_rhs
106-
start_alg = len_rhs_and_sens + model.len_alg
107-
for i, (name, inp) in enumerate(inputs.items()):
108-
if isinstance(inp, numbers.Number):
109-
input_size = 1
110-
else:
111-
input_size = inp.shape[0]
112-
end_rhs = start_rhs + model.len_rhs * input_size
113-
end_alg = start_alg + model.len_alg * input_size
114-
sensitivity[name] = np.vstack(
115-
[y[start_rhs:end_rhs, :], y[start_alg:end_alg, :],]
116-
).reshape(-1, 1)
117-
start_rhs = end_rhs
118-
start_alg = end_alg
120+
# also save the sensitivity wrt each parameter (read the columns of the
121+
# sensitivity matrix)
122+
start = 0
123+
for i, (name, inp) in enumerate(self.inputs.items()):
124+
input_size = inp.shape[0]
125+
end = start + input_size
126+
sensitivity[name] = full_sens_matrix[:, start:end]
127+
start = end
119128
self.sensitivity = sensitivity
120129

121130
self._t_event = t_event
@@ -182,7 +191,10 @@ def inputs(self, inputs):
182191
inp = inp * np.ones((1, len(self.t)))
183192
# Tile a vector
184193
else:
185-
inp = np.tile(inp, len(self.t))
194+
if inp.ndim == 1:
195+
inp = np.tile(inp, (len(self.t), 1)).T
196+
else:
197+
inp = np.tile(inp, len(self.t))
186198
self._inputs[name] = inp
187199

188200
@property

tests/unit/test_models/test_base_model.py

+6
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ def test_boundary_conditions_set_get(self):
9494
with self.assertRaisesRegex(pybamm.ModelError, "boundary condition"):
9595
model.boundary_conditions = bad_bcs
9696

97+
def test_length_scales(self):
98+
model = pybamm.BaseModel()
99+
model.length_scales = {"a": 1.3}
100+
self.assertIsInstance(model.length_scales["a"], pybamm.Scalar)
101+
self.assertEqual(model.length_scales["a"].value, 1.3)
102+
97103
def test_variables_set_get(self):
98104
model = pybamm.BaseModel()
99105
variables = {"c": "alpha", "d": "beta"}

tests/unit/test_solvers/test_scipy_solver.py

+198-7
Original file line numberDiff line numberDiff line change
@@ -383,10 +383,81 @@ def test_solve_sensitivity_scalar_var_scalar_input(self):
383383
],
384384
)
385385

386+
# More complicated model
387+
# Create model
388+
model = pybamm.BaseModel()
389+
var = pybamm.Variable("var")
390+
p = pybamm.InputParameter("p")
391+
q = pybamm.InputParameter("q")
392+
r = pybamm.InputParameter("r")
393+
s = pybamm.InputParameter("s")
394+
model.rhs = {var: p * q}
395+
model.initial_conditions = {var: r}
396+
model.variables = {"var times s": var * s}
397+
398+
# Solve
399+
# Make sure that passing in extra options works
400+
solver = pybamm.ScipySolver(
401+
rtol=1e-10, atol=1e-10, solve_sensitivity_equations=True
402+
)
403+
t_eval = np.linspace(0, 1, 80)
404+
solution = solver.solve(
405+
model, t_eval, inputs={"p": 0.1, "q": 2, "r": -1, "s": 0.5}
406+
)
407+
np.testing.assert_allclose(solution.y[0], -1 + 0.2 * solution.t)
408+
np.testing.assert_allclose(
409+
solution.sensitivity["p"], (2 * solution.t)[:, np.newaxis],
410+
)
411+
np.testing.assert_allclose(
412+
solution.sensitivity["q"], (0.1 * solution.t)[:, np.newaxis],
413+
)
414+
np.testing.assert_allclose(solution.sensitivity["r"], 1)
415+
np.testing.assert_allclose(solution.sensitivity["s"], 0)
416+
np.testing.assert_allclose(
417+
solution.sensitivity["all"],
418+
np.hstack(
419+
[
420+
solution.sensitivity["p"],
421+
solution.sensitivity["q"],
422+
solution.sensitivity["r"],
423+
solution.sensitivity["s"],
424+
]
425+
),
426+
)
427+
np.testing.assert_allclose(
428+
solution["var times s"].data, 0.5 * (-1 + 0.2 * solution.t)
429+
)
430+
np.testing.assert_allclose(
431+
solution["var times s"].sensitivity["p"],
432+
0.5 * (2 * solution.t)[:, np.newaxis],
433+
)
434+
np.testing.assert_allclose(
435+
solution["var times s"].sensitivity["q"],
436+
0.5 * (0.1 * solution.t)[:, np.newaxis],
437+
)
438+
np.testing.assert_allclose(solution["var times s"].sensitivity["r"], 0.5)
439+
np.testing.assert_allclose(
440+
solution["var times s"].sensitivity["s"],
441+
(-1 + 0.2 * solution.t)[:, np.newaxis],
442+
)
443+
np.testing.assert_allclose(
444+
solution["var times s"].sensitivity["all"],
445+
np.hstack(
446+
[
447+
solution["var times s"].sensitivity["p"],
448+
solution["var times s"].sensitivity["q"],
449+
solution["var times s"].sensitivity["r"],
450+
solution["var times s"].sensitivity["s"],
451+
]
452+
),
453+
)
454+
386455
@unittest.skip("")
387456
def test_solve_sensitivity_vector_var_scalar_input(self):
388457
var = pybamm.Variable("var", "negative electrode")
389458
model = pybamm.BaseModel()
459+
# Set length scales to avoid warning
460+
model.length_scales = {"negative electrode": 1}
390461
param = pybamm.InputParameter("param")
391462
model.rhs = {var: -param * var}
392463
model.initial_conditions = {var: 2}
@@ -410,32 +481,152 @@ def test_solve_sensitivity_vector_var_scalar_input(self):
410481
decimal=4,
411482
)
412483

484+
# More complicated model
485+
# Create model
486+
model = pybamm.BaseModel()
487+
# Set length scales to avoid warning
488+
model.length_scales = {"negative electrode": 1}
489+
var = pybamm.Variable("var", "negative electrode")
490+
p = pybamm.InputParameter("p")
491+
q = pybamm.InputParameter("q")
492+
r = pybamm.InputParameter("r")
493+
s = pybamm.InputParameter("s")
494+
model.rhs = {var: p * q}
495+
model.initial_conditions = {var: r}
496+
model.variables = {"var times s": var * s}
497+
498+
# Discretise
499+
disc.process_model(model)
500+
501+
# Solve
502+
# Make sure that passing in extra options works
503+
solver = pybamm.ScipySolver(
504+
rtol=1e-10, atol=1e-10, solve_sensitivity_equations=True
505+
)
506+
t_eval = np.linspace(0, 1, 80)
507+
solution = solver.solve(
508+
model, t_eval, inputs={"p": 0.1, "q": 2, "r": -1, "s": 0.5}
509+
)
510+
np.testing.assert_allclose(solution.y, np.tile(-1 + 0.2 * solution.t, (n, 1)))
511+
np.testing.assert_allclose(
512+
solution.sensitivity["p"], np.repeat(2 * solution.t, n)[:, np.newaxis],
513+
)
514+
np.testing.assert_allclose(
515+
solution.sensitivity["q"], np.repeat(0.1 * solution.t, n)[:, np.newaxis],
516+
)
517+
np.testing.assert_allclose(solution.sensitivity["r"], 1)
518+
np.testing.assert_allclose(solution.sensitivity["s"], 0)
519+
np.testing.assert_allclose(
520+
solution.sensitivity["all"],
521+
np.hstack(
522+
[
523+
solution.sensitivity["p"],
524+
solution.sensitivity["q"],
525+
solution.sensitivity["r"],
526+
solution.sensitivity["s"],
527+
]
528+
),
529+
)
530+
np.testing.assert_allclose(
531+
solution["var times s"].data, np.tile(0.5 * (-1 + 0.2 * solution.t), (n, 1))
532+
)
533+
np.testing.assert_allclose(
534+
solution["var times s"].sensitivity["p"],
535+
np.repeat(0.5 * (2 * solution.t), n)[:, np.newaxis],
536+
)
537+
np.testing.assert_allclose(
538+
solution["var times s"].sensitivity["q"],
539+
np.repeat(0.5 * (0.1 * solution.t), n)[:, np.newaxis],
540+
)
541+
np.testing.assert_allclose(solution["var times s"].sensitivity["r"], 0.5)
542+
np.testing.assert_allclose(
543+
solution["var times s"].sensitivity["s"],
544+
np.repeat(-1 + 0.2 * solution.t, n)[:, np.newaxis],
545+
)
546+
np.testing.assert_allclose(
547+
solution["var times s"].sensitivity["all"],
548+
np.hstack(
549+
[
550+
solution["var times s"].sensitivity["p"],
551+
solution["var times s"].sensitivity["q"],
552+
solution["var times s"].sensitivity["r"],
553+
solution["var times s"].sensitivity["s"],
554+
]
555+
),
556+
)
557+
413558
def test_solve_sensitivity_scalar_var_vector_input(self):
414559
var = pybamm.Variable("var", "negative electrode")
415560
model = pybamm.BaseModel()
561+
# Set length scales to avoid warning
562+
model.length_scales = {"negative electrode": 1}
563+
416564
param = pybamm.InputParameter("param", "negative electrode")
417565
model.rhs = {var: -param * var}
418566
model.initial_conditions = {var: 2}
419-
model.variables = {"x-average of var": pybamm.x_average(var)}
567+
model.variables = {
568+
"var": var,
569+
"integral of var": pybamm.Integral(var, pybamm.standard_spatial_vars.x_n),
570+
}
420571

421572
# create discretisation
422-
mesh = get_mesh_for_testing(xpts=5)
573+
mesh = get_mesh_for_testing()
423574
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
424575
disc = pybamm.Discretisation(mesh, spatial_methods)
425576
disc.process_model(model)
426577
n = disc.mesh["negative electrode"].npts
427578

428-
# Solve - scalar input
429-
solver = pybamm.ScipySolver(solve_sensitivity_equations=True)
430-
t_eval = np.linspace(0, 1, 3)
579+
# Solve - constant input
580+
solver = pybamm.ScipySolver(
581+
rtol=1e-10, atol=1e-10, solve_sensitivity_equations=True
582+
)
583+
t_eval = np.linspace(0, 1)
431584
solution = solver.solve(model, t_eval, inputs={"param": 7 * np.ones(n)})
585+
l_n = mesh["negative electrode"].edges[-1]
432586
np.testing.assert_array_almost_equal(
433587
solution["var"].data, np.tile(2 * np.exp(-7 * t_eval), (n, 1)), decimal=4,
434588
)
589+
435590
np.testing.assert_array_almost_equal(
436591
solution["var"].sensitivity["param"],
437-
np.repeat(-2 * t_eval * np.exp(-7 * t_eval), n)[:, np.newaxis],
438-
decimal=4,
592+
np.vstack([np.eye(n) * -2 * t * np.exp(-7 * t) for t in t_eval]),
593+
)
594+
np.testing.assert_array_almost_equal(
595+
solution["integral of var"].data, 2 * np.exp(-7 * t_eval) * l_n, decimal=4,
596+
)
597+
np.testing.assert_array_almost_equal(
598+
solution["integral of var"].sensitivity["param"],
599+
np.tile(-2 * t_eval * np.exp(-7 * t_eval) * l_n / 40, (40, 1)).T,
600+
)
601+
602+
# Solve - linspace input
603+
solver = pybamm.ScipySolver(
604+
rtol=1e-10, atol=1e-10, solve_sensitivity_equations=True
605+
)
606+
t_eval = np.linspace(0, 1)
607+
p_eval = np.linspace(1, 2, n)
608+
solution = solver.solve(model, t_eval, inputs={"param": p_eval})
609+
l_n = mesh["negative electrode"].edges[-1]
610+
np.testing.assert_array_almost_equal(
611+
solution["var"].data, 2 * np.exp(-p_eval[:, np.newaxis] * t_eval), decimal=4
612+
)
613+
np.testing.assert_array_almost_equal(
614+
solution["var"].sensitivity["param"],
615+
np.vstack([np.diag(-2 * t * np.exp(-p_eval * t)) for t in t_eval]),
616+
)
617+
618+
np.testing.assert_array_almost_equal(
619+
solution["integral of var"].data,
620+
np.sum(
621+
2
622+
* np.exp(-p_eval[:, np.newaxis] * t_eval)
623+
* mesh["negative electrode"].d_edges[:, np.newaxis],
624+
axis=0,
625+
),
626+
)
627+
np.testing.assert_array_almost_equal(
628+
solution["integral of var"].sensitivity["param"],
629+
np.vstack([-2 * t * np.exp(-p_eval * t) * l_n / 40 for t in t_eval]),
439630
)
440631

441632

0 commit comments

Comments
 (0)