Skip to content

Commit f8bc091

Browse files
committed
#1477 improve coverage
1 parent 1b32660 commit f8bc091

File tree

5 files changed

+77
-193
lines changed

5 files changed

+77
-193
lines changed

pybamm/solvers/base_solver.py

+21-17
Original file line numberDiff line numberDiff line change
@@ -607,17 +607,26 @@ def jacp(*args, **kwargs):
607607
)
608608

609609
# if we have changed the equations to include the explicit sensitivity
610-
# equations, then we also need to update the mass matrix
610+
# equations, then we also need to update the mass matrix and bounds
611611
if calculate_sensitivities_explicit:
612-
n_inputs = model.len_rhs_sens // model.len_rhs
612+
if model.len_rhs != 0:
613+
n_inputs = model.len_rhs_sens // model.len_rhs
614+
elif model.len_alg != 0:
615+
n_inputs = model.len_alg_sens // model.len_alg
616+
model.bounds = (
617+
np.repeat(model.bounds[0], n_inputs + 1),
618+
np.repeat(model.bounds[1], n_inputs + 1),
619+
)
613620
if (model.mass_matrix is not None
614621
and model.mass_matrix.shape[0] == model.len_rhs_and_alg):
615-
model.mass_matrix_inv = pybamm.Matrix(
616-
block_diag(
617-
[model.mass_matrix_inv.entries] * (n_inputs + 1),
618-
format="csr"
622+
623+
if model.mass_matrix_inv is not None:
624+
model.mass_matrix_inv = pybamm.Matrix(
625+
block_diag(
626+
[model.mass_matrix_inv.entries] * (n_inputs + 1),
627+
format="csr"
628+
)
619629
)
620-
)
621630
model.mass_matrix = pybamm.Matrix(
622631
block_diag(
623632
[model.mass_matrix.entries] * (n_inputs + 1), format="csr"
@@ -627,10 +636,11 @@ def jacp(*args, **kwargs):
627636
# take care if calculate_sensitivites used then not used
628637
if (model.mass_matrix is not None and
629638
model.mass_matrix.shape[0] > model.len_rhs_and_alg):
630-
model.mass_matrix_inv = pybamm.Matrix(
631-
model.mass_matrix_inv.entries[:model.len_rhs,
632-
:model.len_rhs]
633-
)
639+
if model.mass_matrix_inv is not None:
640+
model.mass_matrix_inv = pybamm.Matrix(
641+
model.mass_matrix_inv.entries[:model.len_rhs,
642+
:model.len_rhs]
643+
)
634644
model.mass_matrix = pybamm.Matrix(
635645
model.mass_matrix.entries[:model.len_rhs_and_alg,
636646
:model.len_rhs_and_alg]
@@ -1428,12 +1438,6 @@ def _set_up_ext_and_inputs(self, model, external_variables, inputs,
14281438
for input_param in model.input_parameters:
14291439
name = input_param.name
14301440
if name not in inputs:
1431-
# Don't allow symbolic inputs if using `sensitivity`
1432-
if calculate_sensitivities:
1433-
raise pybamm.SolverError(
1434-
"Cannot have symbolic inputs if explicitly solving forward"
1435-
"sensitivity equations"
1436-
)
14371441
# Only allow symbolic inputs for CasadiSolver and CasadiAlgebraicSolver
14381442
if not isinstance(
14391443
self, (pybamm.CasadiSolver, pybamm.CasadiAlgebraicSolver)

pybamm/solvers/casadi_algebraic_solver.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,11 @@ def _integrate(self, model, t_eval, inputs_dict=None):
229229
y_sol = casadi.vertcat(y_diff, y_alg)
230230

231231
# Return solution object (no events, so pass None to t_event, y_event)
232+
233+
explicit_sensitivities = bool(model.calculate_sensitivities)
232234
sol = pybamm.Solution(
233-
[t_eval], y_sol, model, inputs_dict, termination="success"
235+
[t_eval], y_sol, model, inputs_dict, termination="success",
236+
sensitivities=explicit_sensitivities
234237
)
235238
sol.integration_time = integration_time
236239
return sol

pybamm/solvers/casadi_solver.py

+1-12
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,6 @@ def _integrate(self, model, t_eval, inputs_dict=None):
123123
Any external variables or input parameters to pass to the model when solving
124124
"""
125125

126-
# are we solving explicit forward equations?
127-
explicit_sensitivities = bool(model.calculate_sensitivities)
128-
129126
# Record whether there are any symbolic inputs
130127
inputs_dict = inputs_dict or {}
131128
has_symbolic_inputs = any(
@@ -275,10 +272,6 @@ def _integrate(self, model, t_eval, inputs_dict=None):
275272
# update y0
276273
y0 = solution.all_ys[-1][:, -1]
277274

278-
# now we extract sensitivities from the solution
279-
if (explicit_sensitivities):
280-
solution.extract_explicit_sensitivities()
281-
282275
return solution
283276

284277
def _solve_for_event(self, coarse_solution, init_event_signs):
@@ -544,11 +537,7 @@ def create_integrator(self, model, inputs, t_eval=None, use_event_switch=False):
544537
# set up and solve
545538
t = casadi.MX.sym("t")
546539
p = casadi.MX.sym("p", inputs.shape[0])
547-
# If the initial conditions depend on inputs, evaluate the function
548-
if isinstance(model.y0, casadi.Function):
549-
y0 = model.y0(p)
550-
else:
551-
y0 = model.y0
540+
y0 = model.y0
552541

553542
y_diff = casadi.MX.sym("y_diff", rhs(0, y0, p).shape[0])
554543
y_alg = casadi.MX.sym("y_alg", algebraic(0, y0, p).shape[0])

pybamm/solvers/solution.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,10 @@ def _extract_explicit_sensitivities(self, model, y, t_eval, inputs):
182182
n_rhs = model.len_rhs
183183
n_alg = model.len_alg
184184
# Get the point where the algebraic equations start
185-
n_p = model.len_rhs_sens // model.len_rhs
185+
if model.len_rhs != 0:
186+
n_p = model.len_rhs_sens // model.len_rhs
187+
else:
188+
n_p = model.len_alg_sens // model.len_alg
186189
len_rhs_and_sens = model.len_rhs + model.len_rhs_sens
187190

188191
n_t = len(t_eval)

tests/unit/test_solvers/test_casadi_solver.py

+47-162
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,38 @@ def test_solve_sensitivity_scalar_var_vector_input(self):
817817
np.vstack([-2 * t * np.exp(-p_eval * t) * l_n / n for t in t_eval]),
818818
)
819819

820+
def test_solve_sensitivity_then_no_sensitivity(self):
821+
# Create model
822+
model = pybamm.BaseModel()
823+
var = pybamm.Variable("var")
824+
p = pybamm.InputParameter("p")
825+
model.rhs = {var: p * var}
826+
model.initial_conditions = {var: 1}
827+
model.variables = {"var squared": var ** 2}
828+
829+
# Solve
830+
# Make sure that passing in extra options works
831+
solver = pybamm.CasadiSolver(
832+
mode="fast", rtol=1e-10, atol=1e-10
833+
)
834+
t_eval = np.linspace(0, 1, 80)
835+
solution = solver.solve(model, t_eval, inputs={"p": 0.1},
836+
calculate_sensitivities=True)
837+
838+
# check sensitivities
839+
np.testing.assert_allclose(
840+
solution.sensitivities["p"],
841+
(solution.t * np.exp(0.1 * solution.t))[:, np.newaxis],
842+
)
843+
844+
solution = solver.solve(model, t_eval, inputs={"p": 0.1})
845+
846+
np.testing.assert_array_equal(solution.t, t_eval)
847+
np.testing.assert_allclose(solution.y, np.exp(0.1 * solution.t).reshape(1, -1))
848+
np.testing.assert_allclose(
849+
solution["var squared"].data, np.exp(0.1 * solution.t) ** 2
850+
)
851+
820852

821853
class TestCasadiSolverDAEsWithForwardSensitivityEquations(unittest.TestCase):
822854
def test_solve_sensitivity_scalar_var_scalar_input(self):
@@ -858,180 +890,33 @@ def test_solve_sensitivity_scalar_var_scalar_input(self):
858890
atol=1e-7
859891
)
860892

861-
def test_solve_sensitivity_vector_var_scalar_input(self):
862-
var = pybamm.Variable("var", "negative electrode")
863-
model = pybamm.BaseModel()
864-
# Set length scales to avoid warning
865-
model.length_scales = {"negative electrode": 1}
866-
param = pybamm.InputParameter("param")
867-
model.rhs = {var: -param * var}
868-
model.initial_conditions = {var: 2}
869-
model.variables = {"var": var}
870-
871-
# create discretisation
872-
disc = get_discretisation_for_testing()
873-
disc.process_model(model)
874-
n = disc.mesh["negative electrode"].npts
875-
876-
# Solve - scalar input
877-
solver = pybamm.CasadiSolver()
878-
t_eval = np.linspace(0, 1)
879-
solution = solver.solve(model, t_eval, inputs={"param": 7},
880-
calculate_sensitivities=["param"])
881-
np.testing.assert_array_almost_equal(
882-
solution["var"].data, np.tile(2 * np.exp(-7 * t_eval), (n, 1)), decimal=4,
883-
)
884-
np.testing.assert_array_almost_equal(
885-
solution["var"].sensitivities["param"],
886-
np.repeat(-2 * t_eval * np.exp(-7 * t_eval), n)[:, np.newaxis],
887-
decimal=4,
888-
)
889-
890-
# More complicated model
893+
def test_solve_sensitivity_algebraic(self):
891894
# Create model
892895
model = pybamm.BaseModel()
893-
# Set length scales to avoid warning
894-
model.length_scales = {"negative electrode": 1}
895-
var = pybamm.Variable("var", "negative electrode")
896+
var = pybamm.Variable("var")
896897
p = pybamm.InputParameter("p")
897-
q = pybamm.InputParameter("q")
898-
r = pybamm.InputParameter("r")
899-
s = pybamm.InputParameter("s")
900-
model.rhs = {var: p * q}
901-
model.initial_conditions = {var: r}
902-
model.variables = {"var times s": var * s}
903-
904-
# Discretise
905-
disc.process_model(model)
898+
model.algebraic = {var: var - p * pybamm.t}
899+
model.initial_conditions = {var: 0}
900+
model.variables = {"var squared": var ** 2}
906901

907902
# Solve
908903
# Make sure that passing in extra options works
909-
solver = pybamm.CasadiSolver(
910-
rtol=1e-10, atol=1e-10,
911-
)
904+
solver = pybamm.CasadiAlgebraicSolver(tol=1e-10)
912905
t_eval = np.linspace(0, 1, 80)
913-
solution = solver.solve(
914-
model, t_eval, inputs={"p": 0.1, "q": 2, "r": -1, "s": 0.5},
915-
calculate_sensitivities=True,
916-
)
917-
np.testing.assert_allclose(solution.y, np.tile(-1 + 0.2 * solution.t, (n, 1)))
918-
np.testing.assert_allclose(
919-
solution.sensitivities["p"], np.repeat(2 * solution.t, n)[:, np.newaxis],
920-
)
921-
np.testing.assert_allclose(
922-
solution.sensitivities["q"], np.repeat(0.1 * solution.t, n)[:, np.newaxis],
923-
)
924-
np.testing.assert_allclose(solution.sensitivities["r"], 1)
925-
np.testing.assert_allclose(solution.sensitivities["s"], 0)
926-
np.testing.assert_allclose(
927-
solution.sensitivities["all"],
928-
np.hstack(
929-
[
930-
solution.sensitivities["p"],
931-
solution.sensitivities["q"],
932-
solution.sensitivities["r"],
933-
solution.sensitivities["s"],
934-
]
935-
),
936-
)
906+
solution = solver.solve(model, t_eval, inputs={"p": 0.1},
907+
calculate_sensitivities=True)
908+
np.testing.assert_array_equal(solution.t, t_eval)
909+
np.testing.assert_allclose(solution.y[0], 0.1 * solution.t)
937910
np.testing.assert_allclose(
938-
solution["var times s"].data, np.tile(0.5 * (-1 + 0.2 * solution.t), (n, 1))
911+
solution.sensitivities["p"], solution.t.reshape(-1, 1), atol=1e-7
939912
)
940913
np.testing.assert_allclose(
941-
solution["var times s"].sensitivities["p"],
942-
np.repeat(0.5 * (2 * solution.t), n)[:, np.newaxis],
914+
solution["var squared"].data, (0.1 * solution.t) ** 2
943915
)
944916
np.testing.assert_allclose(
945-
solution["var times s"].sensitivities["q"],
946-
np.repeat(0.5 * (0.1 * solution.t), n)[:, np.newaxis],
947-
)
948-
np.testing.assert_allclose(solution["var times s"].sensitivities["r"], 0.5)
949-
np.testing.assert_allclose(
950-
solution["var times s"].sensitivities["s"],
951-
np.repeat(-1 + 0.2 * solution.t, n)[:, np.newaxis],
952-
)
953-
np.testing.assert_allclose(
954-
solution["var times s"].sensitivities["all"],
955-
np.hstack(
956-
[
957-
solution["var times s"].sensitivities["p"],
958-
solution["var times s"].sensitivities["q"],
959-
solution["var times s"].sensitivities["r"],
960-
solution["var times s"].sensitivities["s"],
961-
]
962-
),
963-
)
964-
965-
def test_solve_sensitivity_scalar_var_vector_input(self):
966-
var = pybamm.Variable("var", "negative electrode")
967-
model = pybamm.BaseModel()
968-
# Set length scales to avoid warning
969-
model.length_scales = {"negative electrode": 1}
970-
971-
param = pybamm.InputParameter("param", "negative electrode")
972-
model.rhs = {var: -param * var}
973-
model.initial_conditions = {var: 2}
974-
model.variables = {
975-
"var": var,
976-
"integral of var": pybamm.Integral(var, pybamm.standard_spatial_vars.x_n),
977-
}
978-
979-
# create discretisation
980-
mesh = get_mesh_for_testing(xpts=5)
981-
spatial_methods = {"macroscale": pybamm.FiniteVolume()}
982-
disc = pybamm.Discretisation(mesh, spatial_methods)
983-
disc.process_model(model)
984-
n = disc.mesh["negative electrode"].npts
985-
986-
# Solve - constant input
987-
solver = pybamm.CasadiSolver(
988-
mode="fast", rtol=1e-10, atol=1e-10
989-
)
990-
t_eval = np.linspace(0, 1)
991-
solution = solver.solve(model, t_eval, inputs={"param": 7 * np.ones(n)},
992-
calculate_sensitivities=True)
993-
l_n = mesh["negative electrode"].edges[-1]
994-
np.testing.assert_array_almost_equal(
995-
solution["var"].data, np.tile(2 * np.exp(-7 * t_eval), (n, 1)), decimal=4,
996-
)
997-
998-
np.testing.assert_array_almost_equal(
999-
solution["var"].sensitivities["param"],
1000-
np.vstack([np.eye(n) * -2 * t * np.exp(-7 * t) for t in t_eval]),
1001-
)
1002-
np.testing.assert_array_almost_equal(
1003-
solution["integral of var"].data, 2 * np.exp(-7 * t_eval) * l_n, decimal=4,
1004-
)
1005-
np.testing.assert_array_almost_equal(
1006-
solution["integral of var"].sensitivities["param"],
1007-
np.tile(-2 * t_eval * np.exp(-7 * t_eval) * l_n / n, (n, 1)).T,
1008-
)
1009-
1010-
# Solve - linspace input
1011-
p_eval = np.linspace(1, 2, n)
1012-
solution = solver.solve(model, t_eval, inputs={"param": p_eval},
1013-
calculate_sensitivities=True)
1014-
l_n = mesh["negative electrode"].edges[-1]
1015-
np.testing.assert_array_almost_equal(
1016-
solution["var"].data, 2 * np.exp(-p_eval[:, np.newaxis] * t_eval), decimal=4
1017-
)
1018-
np.testing.assert_array_almost_equal(
1019-
solution["var"].sensitivities["param"],
1020-
np.vstack([np.diag(-2 * t * np.exp(-p_eval * t)) for t in t_eval]),
1021-
)
1022-
1023-
np.testing.assert_array_almost_equal(
1024-
solution["integral of var"].data,
1025-
np.sum(
1026-
2
1027-
* np.exp(-p_eval[:, np.newaxis] * t_eval)
1028-
* mesh["negative electrode"].d_edges[:, np.newaxis],
1029-
axis=0,
1030-
),
1031-
)
1032-
np.testing.assert_array_almost_equal(
1033-
solution["integral of var"].sensitivities["param"],
1034-
np.vstack([-2 * t * np.exp(-p_eval * t) * l_n / n for t in t_eval]),
917+
solution["var squared"].sensitivities["p"],
918+
(2 * 0.1 * solution.t ** 2).reshape(-1, 1),
919+
atol=1e-7
1035920
)
1036921

1037922

0 commit comments

Comments
 (0)