Skip to content

Commit 06984c4

Browse files
committedAug 2, 2021
#1477 check sensitivities with fd in integration tests
1 parent 03528da commit 06984c4

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed
 

‎pybamm/solvers/base_solver.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -608,18 +608,33 @@ def jacp(*args, **kwargs):
608608

609609
# if we have changed the equations to include the explicit sensitivity
610610
# equations, then we also need to update the mass matrix
611+
n_inputs = model.len_rhs_sens // model.len_rhs
612+
n_state_without_sens = model.len_rhs_and_alg
611613
if calculate_sensitivities_explicit:
612-
n_inputs = model.len_rhs_sens // model.len_rhs
613-
model.mass_matrix_inv = pybamm.Matrix(
614-
block_diag(
615-
[model.mass_matrix_inv.entries] * (n_inputs + 1), format="csr"
614+
if model.mass_matrix.shape[0] == n_state_without_sens:
615+
model.mass_matrix_inv = pybamm.Matrix(
616+
block_diag(
617+
[model.mass_matrix_inv.entries] * (n_inputs + 1),
618+
format="csr"
619+
)
616620
)
617-
)
618-
model.mass_matrix = pybamm.Matrix(
619-
block_diag(
620-
[model.mass_matrix.entries] * (n_inputs + 1), format="csr"
621+
model.mass_matrix = pybamm.Matrix(
622+
block_diag(
623+
[model.mass_matrix.entries] * (n_inputs + 1), format="csr"
624+
)
625+
)
626+
else:
627+
# take care if calculate_sensitivites used then not used
628+
n_state_with_sens = model.len_rhs_and_alg * (n_inputs + 1)
629+
if model.mass_matrix.shape[0] == n_state_with_sens:
630+
model.mass_matrix_inv = pybamm.Matrix(
631+
model.mass_matrix_inv.entries[:n_state_without_sens,
632+
:n_state_without_sens]
633+
)
634+
model.mass_matrix = pybamm.Matrix(
635+
model.mass_matrix.entries[:n_state_without_sens,
636+
:n_state_without_sens]
621637
)
622-
)
623638

624639
# Save CasADi functions for the CasADi solver
625640
# Note: when we pass to casadi the ode part of the problem must be in

‎tests/integration/test_models/standard_model_tests.py

+17
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,25 @@ def test_sensitivities(self):
101101

102102
self.test_processing_parameters()
103103
self.test_processing_disc()
104+
104105
self.test_solving(inputs=inputs, calculate_sensitivities=True)
105106

107+
# check via finite differencing
108+
h = 1e-6
109+
inputs_plus = {param_name: neg_electrode_cond + 0.5 * h}
110+
inputs_neg = {param_name: neg_electrode_cond - 0.5 * h}
111+
sol_plus = self.solver.solve(
112+
self.model, self.solution.all_ts[0], inputs=inputs_plus
113+
)
114+
sol_neg = self.solver.solve(
115+
self.model, self.solution.all_ts[0], inputs=inputs_neg
116+
)
117+
n = self.solution.sensitivities[param_name].shape[0]
118+
np.testing.assert_array_almost_equal(
119+
self.solution.sensitivities[param_name],
120+
((sol_plus.y - sol_neg.y) / h).reshape((n, 1))
121+
)
122+
106123
if (
107124
isinstance(
108125
self.model, (pybamm.lithium_ion.BaseModel, pybamm.lead_acid.BaseModel)

0 commit comments

Comments
 (0)
Please sign in to comment.