Skip to content

Commit 7c39f3f

Browse files
committed
#1477 update test_sensitivities to use a dae
1 parent 7d97c45 commit 7c39f3f

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed

tests/unit/test_solvers/test_base_solver.py

+22-16
Original file line numberDiff line numberDiff line change
@@ -324,37 +324,43 @@ def test_extrapolation_warnings(self):
324324

325325
def test_sensitivities(self):
326326

327-
def exact_diff_a(v, a, b):
328-
return np.array([v**2 + 2 * a])
327+
def exact_diff_a(y, a, b):
328+
return np.array([
329+
[y[0]**2 + 2 * a],
330+
[y[0]]
331+
])
329332

330-
def exact_diff_b(v, a, b):
331-
return np.array([v])
333+
def exact_diff_b(y, a, b):
334+
return np.array([[y[0]], [0]])
332335

333-
for f in ['', 'python', 'casadi', 'jax']:
336+
for convert_to_format in ['', 'python', 'casadi', 'jax']:
334337
model = pybamm.BaseModel()
335338
v = pybamm.Variable("v")
339+
u = pybamm.Variable("u")
336340
a = pybamm.InputParameter("a")
337341
b = pybamm.InputParameter("b")
338342
model.rhs = {v: a * v**2 + b * v + a**2}
339-
model.initial_conditions = {v: 1}
340-
model.convert_to_format = f
341-
solver = pybamm.ScipySolver()
343+
model.algebraic = {u: a * v - u}
344+
model.initial_conditions = {v: 1, u: a * 1}
345+
model.convert_to_format = convert_to_format
346+
solver = pybamm.CasadiSolver()
342347
solver.set_up(model, calculate_sensitivites=True,
343348
inputs={'a': 0, 'b': 0})
344349
all_inputs = []
345350
for v_value in [0.1, -0.2, 1.5, 8.4]:
346-
for a_value in [0.12, 1.5]:
347-
for b_value in [0.82, 1.9]:
348-
y = np.array([v_value])
349-
t = 0
350-
inputs = {'a': a_value, 'b': b_value}
351-
all_inputs.append((t, y, inputs))
351+
for u_value in [0.13, -0.23, 1.3, 13.4]:
352+
for a_value in [0.12, 1.5]:
353+
for b_value in [0.82, 1.9]:
354+
y = np.array([v_value, u_value])
355+
t = 0
356+
inputs = {'a': a_value, 'b': b_value}
357+
all_inputs.append((t, y, inputs))
352358
for t, y, inputs in all_inputs:
353-
if f == 'casadi':
359+
if model.convert_to_format == 'casadi':
354360
use_inputs = casadi.vertcat(*[x for x in inputs.values()])
355361
else:
356362
use_inputs = inputs
357-
if f == 'jax':
363+
if model.convert_to_format == 'jax':
358364
sens = model.sensitivities_eval(
359365
t, y, use_inputs
360366
)

tests/unit/test_solvers/test_idaklu_solver.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ def test_ida_roberts_klu_sensitivities(self):
6868

6969
t_eval = np.linspace(0, 3, 100)
7070
a_value = 0.1
71-
sol = solver.solve(model, t_eval, inputs={"a": a_value})
71+
sol = solver.solve(
72+
model, t_eval, inputs={"a": a_value},
73+
calculate_sensitivities=True
74+
)
7275

7376
# test that final time is time of event
7477
# y = 0.1 t + y0 so y=0.2 when t=2

0 commit comments

Comments
 (0)