Skip to content

Commit 88ecb3f

Browse files
committed
#1477 do sensitivity integration tests using a processed variable
1 parent bf59b7d commit 88ecb3f

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

pybamm/solvers/processed_variable.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -537,17 +537,18 @@ def initialise_sensitivity_explicit_forward(self):
537537
dvar_dp_func = casadi.Function(
538538
"dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp]
539539
)
540-
for idx in range(len(self.all_ts[0])):
541-
t = self.all_ts[0][idx]
542-
u = self.all_ys[0][:, idx]
543-
next_dvar_dy_eval = dvar_dy_func(t, u, inputs_stacked)
544-
next_dvar_dp_eval = dvar_dp_func(t, u, inputs_stacked)
545-
if idx == 0:
546-
dvar_dy_eval = next_dvar_dy_eval
547-
dvar_dp_eval = next_dvar_dp_eval
548-
else:
549-
dvar_dy_eval = casadi.diagcat(dvar_dy_eval, next_dvar_dy_eval)
550-
dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval)
540+
for index, (ts, ys) in enumerate(zip(self.all_ts, self.all_ys)):
541+
for idx in range(len(ts)):
542+
t = ts[idx]
543+
u = ys[:, idx]
544+
next_dvar_dy_eval = dvar_dy_func(t, u, inputs_stacked)
545+
next_dvar_dp_eval = dvar_dp_func(t, u, inputs_stacked)
546+
if index == 0 and idx == 0:
547+
dvar_dy_eval = next_dvar_dy_eval
548+
dvar_dp_eval = next_dvar_dp_eval
549+
else:
550+
dvar_dy_eval = casadi.diagcat(dvar_dy_eval, next_dvar_dy_eval)
551+
dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval)
551552

552553
# Compute sensitivity
553554
dy_dp = self.solution_sensitivities["all"]

tests/integration/test_models/standard_model_tests.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ def test_outputs(self):
9292
)
9393
std_out_test.test_all()
9494

95-
def test_sensitivities(self, param_name, param_value):
95+
def test_sensitivities(self, param_name, param_value,
96+
output_name='Terminal voltage [V]'):
9697
self.parameter_values.update({param_name: "[input]"})
9798
inputs = {param_name: param_value}
9899

@@ -113,6 +114,7 @@ def test_sensitivities(self, param_name, param_value):
113114
self.model, t_eval, inputs=inputs,
114115
calculate_sensitivities=True
115116
)
117+
output_sens = self.solution[output_name].sensitivities[param_name]
116118

117119
# check via finite differencing
118120
h = 1e-6 * param_value
@@ -121,14 +123,16 @@ def test_sensitivities(self, param_name, param_value):
121123
sol_plus = self.solver.solve(
122124
self.model, t_eval, inputs=inputs_plus,
123125
)
126+
output_plus = sol_plus[output_name](t=t_eval)
124127
sol_neg = self.solver.solve(
125128
self.model, t_eval, inputs=inputs_neg
126129
)
127-
fd = ((np.array(sol_plus.y) - np.array(sol_neg.y)) / h)
130+
output_neg = sol_neg[output_name](t=t_eval)
131+
fd = ((np.array(output_plus) - np.array(output_neg)) / h)
128132
fd = fd.transpose().reshape(-1, 1)
129133
np.testing.assert_allclose(
130-
self.solution.sensitivities[param_name], fd,
131-
rtol=1e-1, atol=1e-5,
134+
output_sens, fd,
135+
rtol=1e-2, atol=1e-6,
132136
)
133137

134138
def test_all(

0 commit comments

Comments
 (0)