Skip to content

Commit 03528da

Browse files
committed
#1477 add some tests and remove uncovered lines not neccessary
1 parent 4c7bbe5 commit 03528da

File tree

5 files changed

+53
-34
lines changed

5 files changed

+53
-34
lines changed

pybamm/expression_tree/operations/evaluate_python.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -685,11 +685,7 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
685685
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
686686
result = result.reshape(result.shape[0], -1)
687687

688-
# don't need known_evals, but need to reproduce Symbol.evaluate signature
689-
if known_evals is not None:
690-
return result, known_evals
691-
else:
692-
return result
688+
return result
693689

694690

695691
class EvaluatorJaxSensitivities:
@@ -708,8 +704,4 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
708704
# execute code
709705
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
710706

711-
# don't need known_evals, but need to reproduce Symbol.evaluate signature
712-
if known_evals is not None:
713-
return result, known_evals
714-
else:
715-
return result
707+
return result

pybamm/solvers/base_solver.py

+2-19
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,6 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
220220
if model.calculate_sensitivities and not isinstance(self, pybamm.IDAKLUSolver):
221221
calculate_sensitivities_explicit = True
222222

223-
if calculate_sensitivities_explicit and model.convert_to_format != 'casadi':
224-
raise NotImplementedError(
225-
"Sensitivities only supported for:\n"
226-
" - model.convert_to_format = 'casadi'\n"
227-
" - IDAKLUSolver (any convert_to_format)"
228-
)
229-
230223
# if we are calculating sensitivities explicitly then the number of
231224
# equations will change
232225
if calculate_sensitivities_explicit:
@@ -284,12 +277,7 @@ def report(string):
284277
report(f"Converting {name} to jax")
285278
func = pybamm.EvaluatorJax(func)
286279
jacp = None
287-
if calculate_sensitivities_explicit:
288-
raise NotImplementedError(
289-
"explicit sensitivity equations not supported for "
290-
"convert_to_format='jax'"
291-
)
292-
elif model.calculate_sensitivities:
280+
if model.calculate_sensitivities:
293281
report((
294282
f"Calculating sensitivities for {name} with respect "
295283
f"to parameters {model.calculate_sensitivities} using jax"
@@ -308,12 +296,7 @@ def report(string):
308296
elif model.convert_to_format != "casadi":
309297
# Process with pybamm functions, optionally converting
310298
# to python evaluator
311-
if calculate_sensitivities_explicit:
312-
raise NotImplementedError(
313-
"explicit sensitivity equations not supported for "
314-
"convert_to_format='{}'".format(model.convert_to_format)
315-
)
316-
elif model.calculate_sensitivities:
299+
if model.calculate_sensitivities:
317300
report((
318301
f"Calculating sensitivities for {name} with respect "
319302
f"to parameters {model.calculate_sensitivities}"

pybamm/solvers/solution.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(
9797
elif isinstance(sensitivities, dict):
9898
self._sensitivities = sensitivities
9999
else:
100-
raise RuntimeError('sensitivities arg needs to be a bool or dict')
100+
raise TypeError('sensitivities arg needs to be a bool or dict')
101101

102102
self._t_event = t_event
103103
self._y_event = y_event
@@ -304,10 +304,6 @@ def all_ts(self):
304304
def all_ys(self):
305305
return self._all_ys
306306

307-
@property
308-
def all_ys_and_sens(self):
309-
return self._all_ys_and_sens
310-
311307
@property
312308
def all_models(self):
313309
"""Model(s) used for solution"""

tests/unit/test_solvers/test_processed_variable.py

+42
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,48 @@ def test_processed_variable_0D(self):
5858
)
5959
np.testing.assert_array_equal(processed_var.entries, y_sol[0])
6060

61+
# check empty sensitivity works
62+
63+
def test_processed_variable_0D_no_sensitivity(self):
64+
# without space
65+
t = pybamm.t
66+
y = pybamm.StateVector(slice(0, 1))
67+
var = t * y
68+
var.mesh = None
69+
t_sol = np.linspace(0, 1)
70+
y_sol = np.array([np.linspace(0, 5)])
71+
var_casadi = to_casadi(var, y_sol)
72+
processed_var = pybamm.ProcessedVariable(
73+
[var],
74+
[var_casadi],
75+
pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
76+
warn=False,
77+
)
78+
79+
# test no inputs (i.e. no sensitivity)
80+
self.assertDictEqual(processed_var.sensitivities, {})
81+
82+
# with parameter
83+
t = pybamm.t
84+
y = pybamm.StateVector(slice(0, 1))
85+
a = pybamm.InputParameter('a')
86+
var = t * y * a
87+
var.mesh = None
88+
t_sol = np.linspace(0, 1)
89+
y_sol = np.array([np.linspace(0, 5)])
90+
inputs = {'a': np.array([1.0])}
91+
var_casadi = to_casadi(var, y_sol, inputs=inputs)
92+
processed_var = pybamm.ProcessedVariable(
93+
[var],
94+
[var_casadi],
95+
pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), inputs),
96+
warn=False,
97+
)
98+
99+
# test no sensitivity raises error
100+
with self.assertRaisesRegex(ValueError, 'Cannot compute sensitivities'):
101+
print(processed_var.sensitivities)
102+
61103
def test_processed_variable_1D(self):
62104
t = pybamm.t
63105
var = pybamm.Variable("var", domain=["negative electrode", "separator"])

tests/unit/test_solvers/test_solution.py

+6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ def test_init(self):
2222
self.assertEqual(sol.all_inputs, [{}])
2323
self.assertIsInstance(sol.all_models[0], pybamm.BaseModel)
2424

25+
def test_sensitivities(self):
26+
t = np.linspace(0, 1)
27+
y = np.tile(t, (20, 1))
28+
with self.assertRaises(TypeError):
29+
pybamm.Solution(t, y, pybamm.BaseModel(), {}, sensitivities=1.0)
30+
2531
def test_errors(self):
2632
bad_ts = [np.array([1, 2, 3]), np.array([3, 4, 5])]
2733
sol = pybamm.Solution(

0 commit comments

Comments
 (0)