Skip to content

Commit 345be6a

Browse files
committed
#858 fix test errors
1 parent 7ac7848 commit 345be6a

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

pybamm/expression_tree/state_vector.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -242,13 +242,15 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
242242
return out
243243

244244
def diff(self, variable):
245+
if variable.id == self.id:
246+
return pybamm.Scalar(1)
245247
if variable.id == pybamm.t.id:
246248
return StateVectorDot(*self._y_slices, name=self.name + "'",
247249
domain=self.domain,
248250
auxiliary_domains=self.auxiliary_domains,
249251
evaluation_array=self.evaluation_array)
250252
else:
251-
raise NotImplementedError
253+
return pybamm.Scalar(0)
252254

253255
def _jac(self, variable):
254256
if isinstance(variable, pybamm.StateVector):
@@ -307,12 +309,14 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
307309
return out
308310

309311
def diff(self, variable):
310-
if variable.id == pybamm.t.id:
312+
if variable.id == self.id:
313+
return pybamm.Scalar(1)
314+
elif variable.id == pybamm.t.id:
311315
raise pybamm.ModelError(
312316
"cannot take second time derivative of a state vector"
313317
)
314318
else:
315-
raise NotImplementedError
319+
return pybamm.Scalar(0)
316320

317321
def _jac(self, variable):
318322
if isinstance(variable, pybamm.StateVectorDot):

tests/unit/test_expression_tree/test_state_vector.py

-8
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,6 @@ def test_evaluate_list(self):
3535
y = np.linspace(0, 3, 31)
3636
np.testing.assert_array_almost_equal(sv.evaluate(y=y), y[:, np.newaxis])
3737

38-
def test_diff(self):
39-
a = pybamm.StateVector(slice(0, 10))
40-
with self.assertRaises(NotImplementedError):
41-
a.diff(a)
42-
b = pybamm.StateVectorDot(slice(0, 10))
43-
with self.assertRaises(NotImplementedError):
44-
a.diff(b)
45-
4638
def test_name(self):
4739
sv = pybamm.StateVector(slice(0, 10))
4840
self.assertEqual(sv.name, "y[0:10]")

0 commit comments

Comments
 (0)