Skip to content

Commit 7d97c45

Browse files
committed
#1477 evaluating sensitivities ok for all convert_tos
1 parent 41565da commit 7d97c45

File tree

3 files changed

+72
-23
lines changed

3 files changed

+72
-23
lines changed

pybamm/expression_tree/operations/evaluate.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def get_sensitivities(self):
628628
self._sens_evaluate = jax.jit(jacobian_evaluate,
629629
static_argnums=self._static_argnums)
630630

631-
return EvaluatorJaxJacobian(self._jac_evaluate, self._constants)
631+
return EvaluatorJaxSensitivities(self._sens_evaluate, self._constants)
632632

633633

634634
def debug(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
@@ -687,3 +687,25 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
687687
return result, known_evals
688688
else:
689689
return result
690+
691+
class EvaluatorJaxSensitivities:
692+
def __init__(self, jac_evaluate, constants):
693+
self._jac_evaluate = jac_evaluate
694+
self._constants = constants
695+
696+
def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
697+
"""
698+
Acts as a drop-in replacement for :func:`pybamm.Symbol.evaluate`
699+
"""
700+
# generated code assumes y is a column vector
701+
if y is not None and y.ndim == 1:
702+
y = y.reshape(-1, 1)
703+
704+
# execute code
705+
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
706+
707+
# don't need known_evals, but need to reproduce Symbol.evaluate signature
708+
if known_evals is not None:
709+
return result, known_evals
710+
else:
711+
return result

pybamm/solvers/base_solver.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def report(string):
269269
f"to parameters {calculate_sensitivites} using jax"
270270
))
271271
jacp_dict = func.get_sensitivities()
272+
jacp_dict = jacp_dict.evaluate
272273
else:
273274
jacp_dict = None
274275
if use_jacobian:
@@ -283,13 +284,11 @@ def report(string):
283284
elif model.convert_to_format != "casadi":
284285
# Process with pybamm functions, optionally converting
285286
# to python evaluator
286-
print('calculate_sensitivites = ', calculate_sensitivites)
287287
if calculate_sensitivites:
288288
report((
289289
f"Calculating sensitivities for {name} with respect "
290290
f"to parameters {calculate_sensitivites}"
291291
))
292-
print(type(func))
293292
jacp_dict = {
294293
p: func.diff(pybamm.InputParameter(p))
295294
for p in calculate_sensitivites
@@ -360,10 +359,17 @@ def report(string):
360359
else:
361360
jac_call = None
362361
if jacp_dict is not None:
363-
jacp_call = {
364-
k: SolverCallable(v, name + "_sensitivity_wrt_" + k, model)
365-
for k, v in jacp_dict.items()
366-
}
362+
if model.convert_to_format == "jax":
363+
jacp_call = SolverCallable(
364+
jacp_dict, name + "_sensitivity_wrt_inputs", model
365+
)
366+
else:
367+
jacp_call = {
368+
k: SolverCallable(v, name + "_sensitivity_wrt_" + k, model)
369+
for k, v in jacp_dict.items()
370+
}
371+
else:
372+
jacp_call = None
367373
return func, func_call, jac_call, jacp_call
368374

369375
# Check for heaviside and modulo functions in rhs and algebraic and add
@@ -520,6 +526,8 @@ def report(string):
520526
"rhs", [t_casadi, y_casadi, p_casadi_stacked], [explicit_rhs]
521527
)
522528
model.casadi_algebraic = algebraic
529+
model.casadi_sensitivities_rhs = jacp_rhs
530+
model.casadi_sensitivities_algebraic = jacp_algebraic
523531
if len(model.rhs) == 0:
524532
# No rhs equations: residuals is algebraic only
525533
model.residuals_eval = Residuals(algebraic, "residuals", model)

tests/unit/test_solvers/test_base_solver.py

+35-16
Original file line numberDiff line numberDiff line change
@@ -323,15 +323,14 @@ def test_extrapolation_warnings(self):
323323
solver.solve(model, t_eval=[0, 1])
324324

325325
def test_sensitivities(self):
326-
pybamm.set_logging_level('DEBUG')
327326

328327
def exact_diff_a(v, a, b):
329-
return v**2 + 2 * a
328+
return np.array([v**2 + 2 * a])
330329

331330
def exact_diff_b(v, a, b):
332-
return v
331+
return np.array([v])
333332

334-
for f in ['', 'python', 'casadi']:
333+
for f in ['', 'python', 'casadi', 'jax']:
335334
model = pybamm.BaseModel()
336335
v = pybamm.Variable("v")
337336
a = pybamm.InputParameter("a")
@@ -342,25 +341,45 @@ def exact_diff_b(v, a, b):
342341
solver = pybamm.ScipySolver()
343342
solver.set_up(model, calculate_sensitivites=True,
344343
inputs={'a': 0, 'b': 0})
344+
all_inputs = []
345345
for v_value in [0.1, -0.2, 1.5, 8.4]:
346346
for a_value in [0.12, 1.5]:
347347
for b_value in [0.82, 1.9]:
348348
y = np.array([v_value])
349349
t = 0
350350
inputs = {'a': a_value, 'b': b_value}
351+
all_inputs.append((t, y, inputs))
352+
for t, y, inputs in all_inputs:
353+
if f == 'casadi':
354+
use_inputs = casadi.vertcat(*[x for x in inputs.values()])
355+
else:
356+
use_inputs = inputs
357+
if f == 'jax':
358+
sens = model.sensitivities_eval(
359+
t, y, use_inputs
360+
)
361+
np.testing.assert_array_equal(
362+
sens['a'],
363+
exact_diff_a(y, inputs['a'], inputs['b'])
364+
)
365+
np.testing.assert_array_equal(
366+
sens['b'],
367+
exact_diff_b(y, inputs['a'], inputs['b'])
368+
)
369+
else:
370+
np.testing.assert_array_equal(
371+
model.sensitivities_eval['a'](
372+
t, y, use_inputs
373+
),
374+
exact_diff_a(y, inputs['a'], inputs['b'])
375+
)
376+
np.testing.assert_array_equal(
377+
model.sensitivities_eval['b'](
378+
t, y, use_inputs
379+
),
380+
exact_diff_b(y, inputs['a'], inputs['b'])
381+
)
351382

352-
self.assertAlmostEqual(
353-
model.sensitivities_eval['a'](
354-
t=0, y=y, inputs=inputs
355-
),
356-
exact_diff_a(v_value, a_value, b_value)
357-
)
358-
self.assertAlmostEqual(
359-
model.sensitivities_eval['b'](
360-
t=0, y=y, inputs=inputs
361-
),
362-
exact_diff_b(v_value, a_value, b_value)
363-
)
364383

365384
if __name__ == "__main__":
366385
print("Add -v for more debug output")

0 commit comments

Comments
 (0)