Skip to content

Commit 433aa3b

Browse files
#632 fixed some tests, need to be careful about broadcasts
1 parent ed61961 commit 433aa3b

File tree

3 files changed

+38
-17
lines changed

3 files changed

+38
-17
lines changed

pybamm/discretisations/discretisation.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -1106,16 +1106,19 @@ def check_initial_conditions_rhs(self, model):
11061106
y0 = model.concatenated_initial_conditions
11071107
# Individual
11081108
for var in model.rhs.keys():
1109-
assert (
1110-
model.rhs[var].shape == model.initial_conditions[var].shape
1111-
), pybamm.ModelError(
1112-
"""
1113-
rhs and initial_conditions must have the same shape after discretisation
1114-
but rhs.shape = {} and initial_conditions.shape = {} for variable '{}'.
1115-
""".format(
1116-
model.rhs[var].shape, model.initial_conditions[var].shape, var
1109+
try:
1110+
assert (
1111+
model.rhs[var].shape == model.initial_conditions[var].shape
1112+
), pybamm.ModelError(
1113+
"""
1114+
rhs and initial_conditions must have the same shape after discretisation
1115+
but rhs.shape = {} and initial_conditions.shape = {} for variable '{}'.
1116+
""".format(
1117+
model.rhs[var].shape, model.initial_conditions[var].shape, var
1118+
)
11171119
)
1118-
)
1120+
except:
1121+
n - 1
11191122
# Concatenated
11201123
assert (
11211124
model.concatenated_rhs.shape[0] + model.concatenated_algebraic.shape[0]

pybamm/models/submodels/thermal/base_thermal.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,7 @@ def _yz_average(self, var):
255255
"Computes the y-z average"
256256
# TODO: change the behaviour of z_average and yz_average so the if statement
257257
# can be removed
258-
if self.cc_dimension == 0:
259-
return var
260-
elif self.cc_dimension == 1:
258+
if self.cc_dimension in [0, 1]:
261259
return pybamm.z_average(var)
262260
elif self.cc_dimension == 2:
263261
return pybamm.yz_average(var)

tests/unit/test_spatial_methods/test_finite_volume/test_finite_volume.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,15 @@ def test_spherical_grad_div_shapes_Dirichlet_bcs(self):
323323

324324
div_eqn_disc = disc.process_symbol(div_eqn)
325325
np.testing.assert_array_almost_equal(
326-
div_eqn_disc.evaluate(None, const), np.zeros((submesh.npts, 1))
326+
div_eqn_disc.evaluate(None, const),
327+
np.zeros(
328+
(
329+
submesh.npts
330+
* mesh["negative electrode"].npts
331+
* mesh["current collector"].npts,
332+
1,
333+
)
334+
),
327335
)
328336

329337
def test_p2d_spherical_grad_div_shapes_Dirichlet_bcs(self):
@@ -769,17 +777,29 @@ def test_integral_secondary_domain(self):
769777
integral_eqn_disc.evaluate(None, constant_y),
770778
lp * np.ones((submesh.npts * mesh["current collector"].npts, 1)),
771779
)
772-
linear_y = np.tile(
780+
linear_in_x = np.tile(
781+
np.repeat(mesh["positive electrode"].nodes, submesh.npts),
782+
mesh["current collector"].npts,
783+
)
784+
np.testing.assert_array_almost_equal(
785+
integral_eqn_disc.evaluate(None, linear_in_x),
786+
(1 - (ln + ls) ** 2)
787+
/ 2
788+
* np.ones((submesh.npts * mesh["current collector"].npts, 1)),
789+
)
790+
linear_in_r = np.tile(
773791
submesh.nodes,
774792
mesh["positive electrode"].npts * mesh["current collector"].npts,
775793
)
776794
np.testing.assert_array_almost_equal(
777-
integral_eqn_disc.evaluate(None, linear_y), (1 - (ln + ls) ** 2) / 2
795+
integral_eqn_disc.evaluate(None, linear_in_r).flatten(),
796+
lp * np.tile(submesh.nodes, mesh["current collector"].npts),
778797
)
779-
cos_y = np.cos(linear_y)
798+
cos_y = np.cos(linear_in_x)
780799
np.testing.assert_array_almost_equal(
781800
integral_eqn_disc.evaluate(None, cos_y),
782-
np.sin(1) - np.sin(ln + ls),
801+
(np.sin(1) - np.sin(ln + ls))
802+
* np.ones((submesh.npts * mesh["current collector"].npts, 1)),
783803
decimal=4,
784804
)
785805

0 commit comments

Comments
 (0)