Skip to content

Commit ce574ed

Browse files
#709 fix unit tests
1 parent 07d6c34 commit ce574ed

File tree

9 files changed

+78
-80
lines changed

9 files changed

+78
-80
lines changed

pybamm/expression_tree/functions.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def _function_diff(self, children, idx):
260260

261261
def arcsinh(child):
262262
" Returns arcsinh function of child. "
263-
return Arcsinh(child)
263+
return pybamm.simplify_if_constant(Arcsinh(child), keep_domains=True)
264264

265265

266266
class Cos(SpecificFunction):
@@ -276,7 +276,7 @@ def _function_diff(self, children, idx):
276276

277277
def cos(child):
278278
" Returns cosine function of child. "
279-
return Cos(child)
279+
return pybamm.simplify_if_constant(Cos(child), keep_domains=True)
280280

281281

282282
class Cosh(SpecificFunction):
@@ -292,7 +292,7 @@ def _function_diff(self, children, idx):
292292

293293
def cosh(child):
294294
" Returns hyperbolic cosine function of child. "
295-
return Cosh(child)
295+
return pybamm.simplify_if_constant(Cosh(child), keep_domains=True)
296296

297297

298298
class Exponential(SpecificFunction):
@@ -308,7 +308,7 @@ def _function_diff(self, children, idx):
308308

309309
def exp(child):
310310
" Returns exponential function of child. "
311-
return Exponential(child)
311+
return pybamm.simplify_if_constant(Exponential(child), keep_domains=True)
312312

313313

314314
class Log(SpecificFunction):
@@ -330,7 +330,7 @@ def _function_diff(self, children, idx):
330330
def log(child, base="e"):
331331
" Returns logarithmic function of child (any base, default 'e'). "
332332
if base == "e":
333-
return Log(child)
333+
return pybamm.simplify_if_constant(Log(child), keep_domains=True)
334334
else:
335335
return Log(child) / np.log(base)
336336

@@ -342,17 +342,17 @@ def log10(child):
342342

343343
def max(child):
344344
" Returns max function of child. "
345-
return Function(np.max, child)
345+
return pybamm.simplify_if_constant(Function(np.max, child), keep_domains=True)
346346

347347

348348
def min(child):
349349
" Returns min function of child. "
350-
return Function(np.min, child)
350+
return pybamm.simplify_if_constant(Function(np.min, child), keep_domains=True)
351351

352352

353353
def sech(child):
354354
" Returns hyperbolic sec function of child. "
355-
return 1 / Cosh(child)
355+
return pybamm.simplify_if_constant(1 / Cosh(child), keep_domains=True)
356356

357357

358358
class Sin(SpecificFunction):
@@ -368,7 +368,7 @@ def _function_diff(self, children, idx):
368368

369369
def sin(child):
370370
" Returns sine function of child. "
371-
return Sin(child)
371+
return pybamm.simplify_if_constant(Sin(child), keep_domains=True)
372372

373373

374374
class Sinh(SpecificFunction):
@@ -384,7 +384,7 @@ def _function_diff(self, children, idx):
384384

385385
def sinh(child):
386386
" Returns hyperbolic sine function of child. "
387-
return Sinh(child)
387+
return pybamm.simplify_if_constant(Sinh(child), keep_domains=True)
388388

389389

390390
class Sqrt(SpecificFunction):
@@ -405,7 +405,7 @@ def _function_diff(self, children, idx):
405405

406406
def sqrt(child):
407407
" Returns square root function of child. "
408-
return Sqrt(child)
408+
return pybamm.simplify_if_constant(Sqrt(child), keep_domains=True)
409409

410410

411411
class Tanh(SpecificFunction):
@@ -421,4 +421,4 @@ def _function_diff(self, children, idx):
421421

422422
def tanh(child):
423423
" Returns hyperbolic tan function of child. "
424-
return Tanh(child)
424+
return pybamm.simplify_if_constant(Tanh(child), keep_domains=True)

pybamm/expression_tree/operations/simplify.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77
import numbers
8-
from scipy.sparse import issparse
8+
from scipy.sparse import issparse, csr_matrix
99

1010

1111
def simplify_if_constant(symbol, keep_domains=False):
@@ -32,6 +32,9 @@ def simplify_if_constant(symbol, keep_domains=False):
3232
result, domain=domain, auxiliary_domains=auxiliary_domains
3333
)
3434
else:
35+
# Turn matrix of zeros into sparse matrix
36+
if isinstance(result, np.ndarray) and np.all(result == 0):
37+
result = csr_matrix(result)
3538
return pybamm.Matrix(
3639
result, domain=domain, auxiliary_domains=auxiliary_domains
3740
)

tests/unit/test_expression_tree/test_functions.py

+56-44
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_multi_var_function_cube(arg1, arg2):
2323
class TestFunction(unittest.TestCase):
2424
def test_number_input(self):
2525
# with numbers
26-
log = pybamm.log(10)
26+
log = pybamm.Function(np.log, 10)
2727
self.assertIsInstance(log.children[0], pybamm.Scalar)
2828
self.assertEqual(log.evaluate(), np.log(10))
2929

@@ -127,27 +127,29 @@ def test_function_unnamed(self):
127127

128128
class TestSpecificFunctions(unittest.TestCase):
129129
def test_arcsinh(self):
130-
a = pybamm.Scalar(3)
130+
a = pybamm.InputParameter("a")
131131
fun = pybamm.arcsinh(a)
132132
self.assertIsInstance(fun, pybamm.Arcsinh)
133-
self.assertEqual(fun.evaluate(), np.arcsinh(3))
133+
self.assertEqual(fun.evaluate(u={"a": 3}), np.arcsinh(3))
134134
h = 0.0000001
135135
self.assertAlmostEqual(
136-
fun.diff(a).evaluate(),
137-
(pybamm.arcsinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
136+
fun.diff(a).evaluate(u={"a": 3}),
137+
(pybamm.arcsinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
138+
/ h,
138139
places=5,
139140
)
140141

141142
def test_cos(self):
142-
a = pybamm.Scalar(3)
143+
a = pybamm.InputParameter("a")
143144
fun = pybamm.cos(a)
144145
self.assertIsInstance(fun, pybamm.Cos)
145146
self.assertEqual(fun.children[0].id, a.id)
146-
self.assertEqual(fun.evaluate(), np.cos(3))
147+
self.assertEqual(fun.evaluate(u={"a": 3}), np.cos(3))
147148
h = 0.0000001
148149
self.assertAlmostEqual(
149-
fun.diff(a).evaluate(),
150-
(pybamm.cos(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
150+
fun.diff(a).evaluate(u={"a": 3}),
151+
(pybamm.cos(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
152+
/ h,
151153
places=5,
152154
)
153155

@@ -157,110 +159,120 @@ def test_cos(self):
157159
self.assertEqual(fun.id, fun.simplify().id)
158160

159161
def test_cosh(self):
160-
a = pybamm.Scalar(3)
162+
a = pybamm.InputParameter("a")
161163
fun = pybamm.cosh(a)
162164
self.assertIsInstance(fun, pybamm.Cosh)
163165
self.assertEqual(fun.children[0].id, a.id)
164-
self.assertEqual(fun.evaluate(), np.cosh(3))
166+
self.assertEqual(fun.evaluate(u={"a": 3}), np.cosh(3))
165167
h = 0.0000001
166168
self.assertAlmostEqual(
167-
fun.diff(a).evaluate(),
168-
(pybamm.cosh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
169+
fun.diff(a).evaluate(u={"a": 3}),
170+
(pybamm.cosh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
171+
/ h,
169172
places=5,
170173
)
171174

172175
def test_exp(self):
173-
a = pybamm.Scalar(3)
176+
a = pybamm.InputParameter("a")
174177
fun = pybamm.exp(a)
175178
self.assertIsInstance(fun, pybamm.Exponential)
176179
self.assertEqual(fun.children[0].id, a.id)
177-
self.assertEqual(fun.evaluate(), np.exp(3))
180+
self.assertEqual(fun.evaluate(u={"a": 3}), np.exp(3))
178181
h = 0.0000001
179182
self.assertAlmostEqual(
180-
fun.diff(a).evaluate(),
181-
(pybamm.exp(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
183+
fun.diff(a).evaluate(u={"a": 3}),
184+
(pybamm.exp(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
185+
/ h,
182186
places=5,
183187
)
184188

185189
def test_log(self):
186-
a = pybamm.Scalar(3)
190+
a = pybamm.InputParameter("a")
187191
fun = pybamm.log(a)
188-
self.assertEqual(fun.evaluate(), np.log(3))
192+
self.assertEqual(fun.evaluate(u={"a": 3}), np.log(3))
189193
h = 0.0000001
190194
self.assertAlmostEqual(
191-
fun.diff(a).evaluate(),
192-
(pybamm.log(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
195+
fun.diff(a).evaluate(u={"a": 3}),
196+
(pybamm.log(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
197+
/ h,
193198
places=5,
194199
)
195200

196201
# Base 10
197202
fun = pybamm.log10(a)
198-
self.assertEqual(fun.evaluate(), np.log10(3))
203+
self.assertEqual(fun.evaluate(u={"a": 3}), np.log10(3))
199204
h = 0.0000001
200205
self.assertAlmostEqual(
201-
fun.diff(a).evaluate(),
202-
(pybamm.log10(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
206+
fun.diff(a).evaluate(u={"a": 3}),
207+
(pybamm.log10(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
208+
/ h,
203209
places=5,
204210
)
205211

206212
def test_max(self):
207-
a = pybamm.Vector(np.array([1, 2, 3]))
213+
a = pybamm.StateVector(slice(0, 3))
214+
y_test = np.array([1, 2, 3])
208215
fun = pybamm.max(a)
209216
self.assertIsInstance(fun, pybamm.Function)
210-
self.assertEqual(fun.evaluate(), 3)
217+
self.assertEqual(fun.evaluate(y=y_test), 3)
211218

212219
def test_min(self):
213-
a = pybamm.Vector(np.array([1, 2, 3]))
220+
a = pybamm.StateVector(slice(0, 3))
221+
y_test = np.array([1, 2, 3])
214222
fun = pybamm.min(a)
215223
self.assertIsInstance(fun, pybamm.Function)
216-
self.assertEqual(fun.evaluate(), 1)
224+
self.assertEqual(fun.evaluate(y=y_test), 1)
217225

218226
def test_sin(self):
219-
a = pybamm.Scalar(3)
227+
a = pybamm.InputParameter("a")
220228
fun = pybamm.sin(a)
221229
self.assertIsInstance(fun, pybamm.Sin)
222230
self.assertEqual(fun.children[0].id, a.id)
223-
self.assertEqual(fun.evaluate(), np.sin(3))
231+
self.assertEqual(fun.evaluate(u={"a": 3}), np.sin(3))
224232
h = 0.0000001
225233
self.assertAlmostEqual(
226-
fun.diff(a).evaluate(),
227-
(pybamm.sin(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
234+
fun.diff(a).evaluate(u={"a": 3}),
235+
(pybamm.sin(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
236+
/ h,
228237
places=5,
229238
)
230239

231240
def test_sinh(self):
232-
a = pybamm.Scalar(3)
241+
a = pybamm.InputParameter("a")
233242
fun = pybamm.sinh(a)
234243
self.assertIsInstance(fun, pybamm.Sinh)
235244
self.assertEqual(fun.children[0].id, a.id)
236-
self.assertEqual(fun.evaluate(), np.sinh(3))
245+
self.assertEqual(fun.evaluate(u={"a": 3}), np.sinh(3))
237246
h = 0.0000001
238247
self.assertAlmostEqual(
239-
fun.diff(a).evaluate(),
240-
(pybamm.sinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
248+
fun.diff(a).evaluate(u={"a": 3}),
249+
(pybamm.sinh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
250+
/ h,
241251
places=5,
242252
)
243253

244254
def test_sqrt(self):
245-
a = pybamm.Scalar(3)
255+
a = pybamm.InputParameter("a")
246256
fun = pybamm.sqrt(a)
247257
self.assertIsInstance(fun, pybamm.Sqrt)
248-
self.assertEqual(fun.evaluate(), np.sqrt(3))
258+
self.assertEqual(fun.evaluate(u={"a": 3}), np.sqrt(3))
249259
h = 0.0000001
250260
self.assertAlmostEqual(
251-
fun.diff(a).evaluate(),
252-
(pybamm.sqrt(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
261+
fun.diff(a).evaluate(u={"a": 3}),
262+
(pybamm.sqrt(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
263+
/ h,
253264
places=5,
254265
)
255266

256267
def test_tanh(self):
257-
a = pybamm.Scalar(3)
268+
a = pybamm.InputParameter("a")
258269
fun = pybamm.tanh(a)
259-
self.assertEqual(fun.evaluate(), np.tanh(3))
270+
self.assertEqual(fun.evaluate(u={"a": 3}), np.tanh(3))
260271
h = 0.0000001
261272
self.assertAlmostEqual(
262-
fun.diff(a).evaluate(),
263-
(pybamm.tanh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate()) / h,
273+
fun.diff(a).evaluate(u={"a": 3}),
274+
(pybamm.tanh(pybamm.Scalar(3 + h)).evaluate() - fun.evaluate(u={"a": 3}))
275+
/ h,
264276
places=5,
265277
)
266278

tests/unit/test_expression_tree/test_matrix.py

-7
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,6 @@ def test_matrix_operations(self):
2929
(self.mat @ self.vect).evaluate(), np.array([[5], [2], [3]])
3030
)
3131

32-
def test_matrix_modification(self):
33-
exp = self.mat @ self.mat + self.mat
34-
self.A[0, 0] = -1
35-
self.assertTrue(exp.children[1]._entries[0, 0], -1)
36-
self.assertTrue(exp.children[0].children[0]._entries[0, 0], -1)
37-
self.assertTrue(exp.children[0].children[1]._entries[0, 0], -1)
38-
3932

4033
class TestArray(unittest.TestCase):
4134
def test_name(self):

tests/unit/test_expression_tree/test_operations/test_simplify.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def test_matrix_divide_simplify(self):
518518
expr3 = (m / a).simplify()
519519
self.assertIsInstance(expr3, pybamm.Matrix)
520520
self.assertEqual(expr3.shape, m.shape)
521-
np.testing.assert_array_equal(expr3.evaluate(), np.zeros((10, 10)))
521+
np.testing.assert_array_equal(expr3.evaluate().toarray(), np.zeros((10, 10)))
522522

523523
def test_domain_concatenation_simplify(self):
524524
# create discretisation

tests/unit/test_expression_tree/test_vector.py

-7
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,6 @@ def test_vector_operations(self):
3131
(self.vect * self.vect).evaluate(), np.array([[1], [4], [9]])
3232
)
3333

34-
def test_vector_modification(self):
35-
exp = self.vect * self.vect + self.vect
36-
self.x[0] = -1
37-
self.assertTrue(exp.children[1]._entries[0], -1)
38-
self.assertTrue(exp.children[0].children[0]._entries[0], -1)
39-
self.assertTrue(exp.children[0].children[1]._entries[0], -1)
40-
4134
def test_wrong_size_entries(self):
4235
with self.assertRaisesRegex(
4336
ValueError, "Entries must have 1 dimension or be column vector"

tests/unit/test_parameters/test_geometric_parameters.py

-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def test_macroscale_parameters(self):
3131
self.assertEqual(
3232
(L_n_eval + L_s_eval + L_p_eval).evaluate(), L_x_eval.evaluate()
3333
)
34-
self.assertEqual((L_n_eval + L_s_eval + L_p_eval).id, L_x_eval.id)
3534
l_n_eval = parameter_values.process_symbol(l_n)
3635
l_s_eval = parameter_values.process_symbol(l_s)
3736
l_p_eval = parameter_values.process_symbol(l_p)

tests/unit/test_parameters/test_parameter_values.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -279,25 +279,23 @@ def test_process_function_parameter(self):
279279
"const": 254,
280280
}
281281
)
282-
a = pybamm.Parameter("a")
282+
a = pybamm.InputParameter("a")
283283

284284
# process function
285285
func = pybamm.FunctionParameter("func", a)
286286
processed_func = parameter_values.process_symbol(func)
287-
self.assertEqual(processed_func.evaluate(), 369)
287+
self.assertEqual(processed_func.evaluate(u={"a": 3}), 369)
288288

289289
# process constant function
290290
const = pybamm.FunctionParameter("const", a)
291291
processed_const = parameter_values.process_symbol(const)
292-
self.assertIsInstance(processed_const, pybamm.Multiplication)
293-
self.assertIsInstance(processed_const.left, pybamm.Scalar)
294-
self.assertIsInstance(processed_const.right, pybamm.Scalar)
292+
self.assertIsInstance(processed_const, pybamm.Scalar)
295293
self.assertEqual(processed_const.evaluate(), 254)
296294

297295
# process differentiated function parameter
298296
diff_func = func.diff(a)
299297
processed_diff_func = parameter_values.process_symbol(diff_func)
300-
self.assertEqual(processed_diff_func.evaluate(), 123)
298+
self.assertEqual(processed_diff_func.evaluate(u={"a": 3}), 123)
301299

302300
def test_process_inline_function_parameters(self):
303301
def D(c):

0 commit comments

Comments
 (0)