Skip to content

Commit 7ac7848

Browse files
committed
#858 fix coverage and flake8
1 parent 445ae54 commit 7ac7848

File tree

6 files changed

+26
-16
lines changed

6 files changed

+26
-16
lines changed

pybamm/expression_tree/state_vector.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -242,15 +242,13 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
242242
return out
243243

244244
def diff(self, variable):
245-
if variable.id == self.id:
246-
return pybamm.Scalar(1)
247-
elif variable.id == pybamm.t.id:
245+
if variable.id == pybamm.t.id:
248246
return StateVectorDot(*self._y_slices, name=self.name + "'",
249247
domain=self.domain,
250248
auxiliary_domains=self.auxiliary_domains,
251249
evaluation_array=self.evaluation_array)
252250
else:
253-
return pybamm.Scalar(0)
251+
raise NotImplementedError
254252

255253
def _jac(self, variable):
256254
if isinstance(variable, pybamm.StateVector):
@@ -309,14 +307,12 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
309307
return out
310308

311309
def diff(self, variable):
312-
if variable.id == self.id:
313-
return pybamm.Scalar(1)
314-
elif variable.id == pybamm.t.id:
310+
if variable.id == pybamm.t.id:
315311
raise pybamm.ModelError(
316312
"cannot take second time derivative of a state vector"
317313
)
318314
else:
319-
return pybamm.Scalar(0)
315+
raise NotImplementedError
320316

321317
def _jac(self, variable):
322318
if isinstance(variable, pybamm.StateVectorDot):

tests/unit/test_discretisations/test_discretisation.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,8 @@ def test_process_model_ode(self):
703703
# test that any time derivatives of variables in rhs raises an
704704
# error
705705
model = pybamm.BaseModel()
706-
model.rhs = {c: pybamm.div(N) + c.diff(pybamm.t), T: pybamm.div(q), S: pybamm.div(p)}
706+
model.rhs = {c: pybamm.div(N) + c.diff(pybamm.t),
707+
T: pybamm.div(q), S: pybamm.div(p)}
707708
model.initial_conditions = {
708709
c: pybamm.Scalar(2),
709710
T: pybamm.Scalar(5),
@@ -846,8 +847,6 @@ def test_process_model_dae(self):
846847
with self.assertRaises(pybamm.ModelError):
847848
disc.process_model(model)
848849

849-
850-
851850
def test_process_model_concatenation(self):
852851
# concatenation of variables as the key
853852
cn = pybamm.Variable("c", domain=["negative electrode"])
@@ -1144,6 +1143,7 @@ def test_mass_matirx_inverse(self):
11441143
model.mass_matrix_inv.entries.toarray(), mass_inv.toarray()
11451144
)
11461145

1146+
11471147
if __name__ == "__main__":
11481148
print("Add -v for more debug output")
11491149
import sys

tests/unit/test_expression_tree/test_d_dt.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def test_time_derivative(self):
2222
self.assertEqual(a.simplify().id, (2 * pybamm.t).id)
2323
self.assertEqual(a.evaluate(t=1), 2)
2424

25-
a =(2 + pybamm.t**2).diff(pybamm.t)
26-
self.assertEqual(a.simplify().id, (2*pybamm.t).id)
25+
a = (2 + pybamm.t**2).diff(pybamm.t)
26+
self.assertEqual(a.simplify().id, (2 * pybamm.t).id)
2727
self.assertEqual(a.evaluate(t=1), 2)
2828

2929
def test_time_derivative_of_variable(self):
@@ -33,7 +33,7 @@ def test_time_derivative_of_variable(self):
3333
self.assertEqual(a.name, "a'")
3434

3535
p = pybamm.Parameter('p')
36-
a = (1 + p*pybamm.Variable('a')).diff(pybamm.t).simplify()
36+
a = (1 + p * pybamm.Variable('a')).diff(pybamm.t).simplify()
3737
self.assertIsInstance(a, pybamm.Multiplication)
3838
self.assertEqual(a.children[0].name, 'p')
3939
self.assertEqual(a.children[1].name, "a'")
@@ -59,6 +59,7 @@ def test_time_derivative_of_state_vector(self):
5959
with self.assertRaises(pybamm.ModelError):
6060
a = (sv).diff(pybamm.t).diff(pybamm.t)
6161

62+
6263
if __name__ == "__main__":
6364
print("Add -v for more debug output")
6465
import sys

tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ def test_convert_external_variable(self):
201201

202202
# External only
203203
self.assert_casadi_equal(
204-
pybamm_u1.to_casadi(casadi_t, casadi_y, u=casadi_us), casadi_us["External 1"]
204+
pybamm_u1.to_casadi(casadi_t, casadi_y, u=casadi_us),
205+
casadi_us["External 1"]
205206
)
206207

207208
# More complex

tests/unit/test_expression_tree/test_state_vector.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ def test_evaluate_list(self):
3535
y = np.linspace(0, 3, 31)
3636
np.testing.assert_array_almost_equal(sv.evaluate(y=y), y[:, np.newaxis])
3737

38+
def test_diff(self):
39+
a = pybamm.StateVector(slice(0, 10))
40+
with self.assertRaises(NotImplementedError):
41+
a.diff(a)
42+
b = pybamm.StateVectorDot(slice(0, 10))
43+
with self.assertRaises(NotImplementedError):
44+
a.diff(b)
45+
3846
def test_name(self):
3947
sv = pybamm.StateVector(slice(0, 10))
4048
self.assertEqual(sv.name, "y[0:10]")
@@ -61,6 +69,7 @@ def test_failure(self):
6169
with self.assertRaisesRegex(TypeError, "all y_slices must be slice objects"):
6270
pybamm.StateVector(slice(0, 10), 1)
6371

72+
6473
class TestStateVectorDot(unittest.TestCase):
6574
def test_evaluate(self):
6675
sv = pybamm.StateVectorDot(slice(0, 10))
@@ -72,14 +81,16 @@ def test_evaluate(self):
7281
# Try evaluating with a y that is too short
7382
y_dot2 = np.ones(5)
7483
with self.assertRaisesRegex(
75-
ValueError, "y_dot is too short, so value with slice is smaller than expected"
84+
ValueError,
85+
"y_dot is too short, so value with slice is smaller than expected"
7686
):
7787
sv.evaluate(y_dot=y_dot2)
7888

7989
def test_name(self):
8090
sv = pybamm.StateVectorDot(slice(0, 10))
8191
self.assertEqual(sv.name, "y_dot[0:10]")
8292

93+
8394
if __name__ == "__main__":
8495
print("Add -v for more debug output")
8596
import sys

tests/unit/test_parameters/test_parameters_cli.py

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def test_list_params(self):
120120
# but must not intefere with existing input dir if it exists
121121
# in the current dir...
122122

123+
123124
if __name__ == "__main__":
124125
print("Add -v for more debug output")
125126
import sys

0 commit comments

Comments
 (0)