Skip to content

Commit f6b0ca9

Browse files
#709 first go at simplifying on creation
1 parent d1344e3 commit f6b0ca9

File tree

7 files changed

+106
-52
lines changed

7 files changed

+106
-52
lines changed

pybamm/expression_tree/operations/simplify.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,17 @@
88
from scipy.sparse import issparse
99

1010

11-
def simplify_if_constant(symbol):
11+
def simplify_if_constant(symbol, keep_domains=False):
1212
"""
1313
Utility function to simplify an expression tree if it evalutes to a constant
1414
scalar, vector or matrix
1515
"""
16+
if keep_domains is True:
17+
domain = symbol.domain
18+
auxiliary_domains = symbol.auxiliary_domains
19+
else:
20+
domain = None
21+
auxiliary_domains = None
1622
if symbol.is_constant():
1723
result = symbol.evaluate_ignoring_errors()
1824
if result is not None:
@@ -22,9 +28,13 @@ def simplify_if_constant(symbol):
2228
return pybamm.Scalar(result)
2329
elif isinstance(result, np.ndarray) or issparse(result):
2430
if result.ndim == 1 or result.shape[1] == 1:
25-
return pybamm.Vector(result)
31+
return pybamm.Vector(
32+
result, domain=domain, auxiliary_domains=auxiliary_domains
33+
)
2634
else:
27-
return pybamm.Matrix(result)
35+
return pybamm.Matrix(
36+
result, domain=domain, auxiliary_domains=auxiliary_domains
37+
)
2838

2939
return symbol
3040

pybamm/expression_tree/symbol.py

+49-19
Original file line numberDiff line numberDiff line change
@@ -360,79 +360,109 @@ def __repr__(self):
360360

361361
def __add__(self, other):
362362
"""return an :class:`Addition` object"""
363-
return pybamm.Addition(self, other)
363+
return pybamm.simplify_if_constant(
364+
pybamm.Addition(self, other), keep_domains=True
365+
)
364366

365367
def __radd__(self, other):
366368
"""return an :class:`Addition` object"""
367-
return pybamm.Addition(other, self)
369+
return pybamm.simplify_if_constant(
370+
pybamm.Addition(other, self), keep_domains=True
371+
)
368372

369373
def __sub__(self, other):
370374
"""return a :class:`Subtraction` object"""
371-
return pybamm.Subtraction(self, other)
375+
return pybamm.simplify_if_constant(
376+
pybamm.Subtraction(self, other), keep_domains=True
377+
)
372378

373379
def __rsub__(self, other):
374380
"""return a :class:`Subtraction` object"""
375-
return pybamm.Subtraction(other, self)
381+
return pybamm.simplify_if_constant(
382+
pybamm.Subtraction(other, self), keep_domains=True
383+
)
376384

377385
def __mul__(self, other):
378386
"""return a :class:`Multiplication` object"""
379-
return pybamm.Multiplication(self, other)
387+
return pybamm.simplify_if_constant(
388+
pybamm.Multiplication(self, other), keep_domains=True
389+
)
380390

381391
def __rmul__(self, other):
382392
"""return a :class:`Multiplication` object"""
383-
return pybamm.Multiplication(other, self)
393+
return pybamm.simplify_if_constant(
394+
pybamm.Multiplication(other, self), keep_domains=True
395+
)
384396

385397
def __matmul__(self, other):
386398
"""return a :class:`MatrixMultiplication` object"""
387-
return pybamm.MatrixMultiplication(self, other)
399+
return pybamm.simplify_if_constant(
400+
pybamm.MatrixMultiplication(self, other), keep_domains=True
401+
)
388402

389403
def __rmatmul__(self, other):
390404
"""return a :class:`MatrixMultiplication` object"""
391-
return pybamm.MatrixMultiplication(other, self)
405+
return pybamm.simplify_if_constant(
406+
pybamm.MatrixMultiplication(other, self), keep_domains=True
407+
)
392408

393409
def __truediv__(self, other):
394410
"""return a :class:`Division` object"""
395-
return pybamm.Division(self, other)
411+
return pybamm.simplify_if_constant(
412+
pybamm.Division(self, other), keep_domains=True
413+
)
396414

397415
def __rtruediv__(self, other):
398416
"""return a :class:`Division` object"""
399-
return pybamm.Division(other, self)
417+
return pybamm.simplify_if_constant(
418+
pybamm.Division(other, self), keep_domains=True
419+
)
400420

401421
def __pow__(self, other):
402422
"""return a :class:`Power` object"""
403-
return pybamm.Power(self, other)
423+
return pybamm.simplify_if_constant(pybamm.Power(self, other), keep_domains=True)
404424

405425
def __rpow__(self, other):
406426
"""return a :class:`Power` object"""
407-
return pybamm.Power(other, self)
427+
return pybamm.simplify_if_constant(pybamm.Power(other, self), keep_domains=True)
408428

409429
def __lt__(self, other):
410430
"""return a :class:`Heaviside` object"""
411-
return pybamm.Heaviside(self, other, equal=False)
431+
return pybamm.simplify_if_constant(
432+
pybamm.Heaviside(self, other, equal=False), keep_domains=True
433+
)
412434

413435
def __le__(self, other):
414436
"""return a :class:`Heaviside` object"""
415-
return pybamm.Heaviside(self, other, equal=True)
437+
return pybamm.simplify_if_constant(
438+
pybamm.Heaviside(self, other, equal=True), keep_domains=True
439+
)
416440

417441
def __gt__(self, other):
418442
"""return a :class:`Heaviside` object"""
419-
return pybamm.Heaviside(other, self, equal=False)
443+
return pybamm.simplify_if_constant(
444+
pybamm.Heaviside(other, self, equal=False), keep_domains=True
445+
)
420446

421447
def __ge__(self, other):
422448
"""return a :class:`Heaviside` object"""
423-
return pybamm.Heaviside(other, self, equal=True)
449+
return pybamm.simplify_if_constant(
450+
pybamm.Heaviside(other, self, equal=True), keep_domains=True
451+
)
424452

425453
def __neg__(self):
426454
"""return a :class:`Negate` object"""
427-
return pybamm.Negate(self)
455+
return pybamm.simplify_if_constant(pybamm.Negate(self), keep_domains=True)
428456

429457
def __abs__(self):
430458
"""return an :class:`AbsoluteValue` object"""
431-
return pybamm.AbsoluteValue(self)
459+
return pybamm.simplify_if_constant(
460+
pybamm.AbsoluteValue(self), keep_domains=True
461+
)
432462

433463
def __getitem__(self, key):
434464
"""return a :class:`Index` object"""
435-
return pybamm.Index(self, key)
465+
return pybamm.simplify_if_constant(pybamm.Index(self, key), keep_domains=True)
436466

437467
def diff(self, variable):
438468
"""

pybamm/spatial_methods/finite_volume.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1025,7 +1025,9 @@ def process_binary_operators(self, bin_op, left, right, disc_left, disc_right):
10251025
method = "arithmetic"
10261026
disc_left = self.node_to_edge(disc_left, method=method)
10271027
# Return new binary operator with appropriate class
1028-
out = bin_op.__class__(disc_left, disc_right)
1028+
out = pybamm.simplify_if_constant(
1029+
bin_op.__class__(disc_left, disc_right), keep_domains=True
1030+
)
10291031
return out
10301032

10311033
def concatenation(self, disc_children):

tests/unit/test_discretisations/test_discretisation.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,10 @@ def test_process_symbol_base(self):
251251
self.assertIsInstance(un1_disc, pybamm.Negate)
252252
self.assertIsInstance(un1_disc.children[0], pybamm.StateVector)
253253

254-
un2 = abs(scal)
254+
un2 = abs(var)
255255
un2_disc = disc.process_symbol(un2)
256256
self.assertIsInstance(un2_disc, pybamm.AbsoluteValue)
257-
self.assertIsInstance(un2_disc.children[0], pybamm.Scalar)
257+
self.assertIsInstance(un2_disc.children[0], pybamm.StateVector)
258258

259259
# function of one variable
260260
def myfun(x):
@@ -749,7 +749,7 @@ def test_process_empty_model(self):
749749
def test_broadcast(self):
750750
whole_cell = ["negative electrode", "separator", "positive electrode"]
751751

752-
a = pybamm.Scalar(7)
752+
a = pybamm.InputParameter("a")
753753
var = pybamm.Variable("var")
754754

755755
# create discretisation
@@ -761,13 +761,14 @@ def test_broadcast(self):
761761
# scalar
762762
broad = disc.process_symbol(pybamm.FullBroadcast(a, whole_cell, {}))
763763
np.testing.assert_array_equal(
764-
broad.evaluate(), 7 * np.ones_like(combined_submesh[0].nodes[:, np.newaxis])
764+
broad.evaluate(u={"a": 7}),
765+
7 * np.ones_like(combined_submesh[0].nodes[:, np.newaxis]),
765766
)
766767
self.assertEqual(broad.domain, whole_cell)
767768

768769
broad_disc = disc.process_symbol(broad)
769770
self.assertIsInstance(broad_disc, pybamm.Multiplication)
770-
self.assertIsInstance(broad_disc.children[0], pybamm.Scalar)
771+
self.assertIsInstance(broad_disc.children[0], pybamm.InputParameter)
771772
self.assertIsInstance(broad_disc.children[1], pybamm.Vector)
772773

773774
# process Broadcast variable
@@ -804,16 +805,14 @@ def test_secondary_broadcast_2D(self):
804805
# secondary broadcast in 2D --> Matrix multiplication
805806
disc = get_discretisation_for_testing()
806807
mesh = disc.mesh
807-
var = pybamm.Vector(
808-
mesh["negative particle"][0].nodes, domain=["negative particle"]
809-
)
808+
var = pybamm.Variable("var", domain=["negative particle"])
810809
broad = pybamm.SecondaryBroadcast(var, "negative electrode")
811810

812811
disc.set_variable_slices([var])
813812
broad_disc = disc.process_symbol(broad)
814813
self.assertIsInstance(broad_disc, pybamm.MatrixMultiplication)
815814
self.assertIsInstance(broad_disc.children[0], pybamm.Matrix)
816-
self.assertIsInstance(broad_disc.children[1], pybamm.Vector)
815+
self.assertIsInstance(broad_disc.children[1], pybamm.StateVector)
817816
self.assertEqual(
818817
broad_disc.shape,
819818
(mesh["negative particle"][0].npts * mesh["negative electrode"][0].npts, 1),
@@ -905,7 +904,9 @@ def test_exceptions(self):
905904

906905
# check doesn't raise if broadcast
907906
model.variables = {
908-
c_n.name: pybamm.PrimaryBroadcast(pybamm.Scalar(2), ["negative electrode"])
907+
c_n.name: pybamm.PrimaryBroadcast(
908+
pybamm.InputParameter("a"), ["negative electrode"]
909+
)
909910
}
910911
disc.process_model(model)
911912

tests/unit/test_expression_tree/test_binary_operators.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ def test_addition(self):
4545
self.assertEqual(summ.children[0].name, a.name)
4646
self.assertEqual(summ.children[1].name, b.name)
4747

48+
# test simplifying
49+
summ2 = pybamm.Scalar(1) + pybamm.Scalar(3)
50+
self.assertEqual(summ2.id, pybamm.Scalar(4).id)
51+
4852
def test_power(self):
4953
a = pybamm.Symbol("a")
5054
b = pybamm.Symbol("b")
@@ -61,22 +65,28 @@ def test_power(self):
6165
def test_known_eval(self):
6266
# Scalars
6367
a = pybamm.Scalar(4)
64-
b = pybamm.Scalar(2)
68+
b = pybamm.StateVector(slice(0, 1))
6569
expr = (a + b) - (a + b) * (a + b)
66-
value = expr.evaluate()
67-
self.assertEqual(expr.evaluate(known_evals={})[0], value)
68-
self.assertIn((a + b).id, expr.evaluate(known_evals={})[1])
69-
self.assertEqual(expr.evaluate(known_evals={})[1][(a + b).id], 6)
70+
value = expr.evaluate(y=np.array([2]))
71+
self.assertEqual(expr.evaluate(y=np.array([2]), known_evals={})[0], value)
72+
self.assertIn((a + b).id, expr.evaluate(y=np.array([2]), known_evals={})[1])
73+
self.assertEqual(
74+
expr.evaluate(y=np.array([2]), known_evals={})[1][(a + b).id], 6
75+
)
7076

7177
# Matrices
7278
a = pybamm.Matrix(np.random.rand(5, 5))
73-
b = pybamm.Matrix(np.random.rand(5, 5))
79+
b = pybamm.StateVector(slice(0, 5))
7480
expr2 = (a @ b) - (a @ b) * (a @ b) + (a @ b)
75-
value = expr2.evaluate()
76-
np.testing.assert_array_equal(expr2.evaluate(known_evals={})[0], value)
77-
self.assertIn((a @ b).id, expr2.evaluate(known_evals={})[1])
81+
y_test = np.linspace(0, 1, 5)
82+
value = expr2.evaluate(y=y_test)
83+
np.testing.assert_array_equal(
84+
expr2.evaluate(y=y_test, known_evals={})[0], value
85+
)
86+
self.assertIn((a @ b).id, expr2.evaluate(y=y_test, known_evals={})[1])
7887
np.testing.assert_array_equal(
79-
expr2.evaluate(known_evals={})[1][(a @ b).id], (a @ b).evaluate()
88+
expr2.evaluate(y=y_test, known_evals={})[1][(a @ b).id],
89+
(a @ b).evaluate(y=y_test),
8090
)
8191

8292
def test_diff(self):
@@ -158,7 +168,7 @@ def test_id(self):
158168
def test_number_overloading(self):
159169
a = pybamm.Scalar(4)
160170
prod = a * 3
161-
self.assertIsInstance(prod.children[1], pybamm.Scalar)
171+
self.assertIsInstance(prod, pybamm.Scalar)
162172
self.assertEqual(prod.evaluate(), 12)
163173

164174
def test_sparse_multiply(self):

tests/unit/test_expression_tree/test_symbol.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def test_symbol_evaluates_to_number(self):
194194
a = pybamm.Parameter("a")
195195
self.assertFalse(a.evaluates_to_number())
196196

197-
a = pybamm.Scalar(3) * pybamm.Scalar(2)
197+
a = pybamm.Scalar(3) * pybamm.Time()
198198
self.assertTrue(a.evaluates_to_number())
199199
# highlight difference between this function and isinstance(a, Scalar)
200200
self.assertNotIsInstance(a, pybamm.Scalar)
@@ -339,10 +339,10 @@ def test_has_spatial_derivatives(self):
339339

340340
def test_orphans(self):
341341
a = pybamm.Scalar(1)
342-
b = pybamm.Scalar(2)
343-
sum = a + b
342+
b = pybamm.Parameter("b")
343+
summ = a + b
344344

345-
a_orp, b_orp = sum.orphans
345+
a_orp, b_orp = summ.orphans
346346
self.assertIsNone(a_orp.parent)
347347
self.assertIsNone(b_orp.parent)
348348
self.assertEqual(a.id, a_orp.id)

tests/unit/test_expression_tree/test_unary_operators.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -124,22 +124,23 @@ def test_integral(self):
124124
pybamm.Integral(a, y)
125125

126126
def test_index(self):
127-
vec = pybamm.Vector(np.array([1, 2, 3, 4, 5]))
127+
vec = pybamm.StateVector(slice(0, 5))
128+
y_test = np.array([1, 2, 3, 4, 5])
128129
# with integer
129130
ind = vec[3]
130131
self.assertIsInstance(ind, pybamm.Index)
131132
self.assertEqual(ind.slice, slice(3, 4))
132-
self.assertEqual(ind.evaluate(), 4)
133+
self.assertEqual(ind.evaluate(y=y_test), 4)
133134
# with slice
134135
ind = vec[1:3]
135136
self.assertIsInstance(ind, pybamm.Index)
136137
self.assertEqual(ind.slice, slice(1, 3))
137-
np.testing.assert_array_equal(ind.evaluate(), np.array([[2], [3]]))
138+
np.testing.assert_array_equal(ind.evaluate(y=y_test), np.array([[2], [3]]))
138139
# with only stop slice
139140
ind = vec[:3]
140141
self.assertIsInstance(ind, pybamm.Index)
141142
self.assertEqual(ind.slice, slice(3))
142-
np.testing.assert_array_equal(ind.evaluate(), np.array([[1], [2], [3]]))
143+
np.testing.assert_array_equal(ind.evaluate(y=y_test), np.array([[1], [2], [3]]))
143144

144145
# errors
145146
with self.assertRaisesRegex(TypeError, "index must be integer or slice"):

0 commit comments

Comments
 (0)