Skip to content

Commit fb794ff

Browse files
committed
#704 updated tests
1 parent 89d0d16 commit fb794ff

File tree

3 files changed

+21
-11
lines changed

3 files changed

+21
-11
lines changed

pybamm/spatial_methods/spatial_method.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,13 @@ def __init__(self, options=None):
2424

2525
self.options = {"extrapolation": {"order": "quadratic", "use bcs": True}}
2626

27+
# update double-layered dict
2728
if options:
28-
self.options.update(options)
29+
for opt, val in options.items():
30+
if isinstance(val, dict):
31+
self.options[opt].update(val)
32+
else:
33+
self.options[opt] = val
2934

3035
self._mesh = None
3136

tests/unit/test_spatial_methods/test_finite_volume/__init__.py

Whitespace-only changes.

tests/unit/test_spatial_methods/test_finite_volume/test_extrapolation.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import unittest
1616

1717

18-
def errors(pts, function, extrap):
18+
def errors(pts, function, method_options):
1919

2020
domain = "test"
2121
x = pybamm.SpatialVariable("x", domain=domain)
@@ -26,10 +26,9 @@ def errors(pts, function, extrap):
2626
var_pts = {x: pts}
2727
mesh = pybamm.Mesh(geometry, submesh_types, var_pts)
2828

29-
spatial_methods = {"test": pybamm.FiniteVolume}
29+
spatial_methods = {"test": pybamm.FiniteVolume(method_options)}
3030
disc = pybamm.Discretisation(mesh, spatial_methods)
3131

32-
disc.spatial_methods["test"].extrapolation = extrap
3332
var = pybamm.Variable("var", domain="test")
3433
left_extrap = pybamm.BoundaryValue(var, "left")
3534
right_extrap = pybamm.BoundaryValue(var, "right")
@@ -47,21 +46,23 @@ def errors(pts, function, extrap):
4746
return l_error, r_error
4847

4948

50-
def get_errors(function, extrap, pts):
49+
def get_errors(function, method_options, pts):
5150

5251
l_errors = np.zeros(pts.shape)
5352
r_errors = np.zeros(pts.shape)
5453

5554
for i, pt in enumerate(pts):
56-
l_errors[i], r_errors[i] = errors(pt, function, extrap)
55+
l_errors[i], r_errors[i] = errors(pt, function, method_options)
5756

5857
return l_errors, r_errors
5958

6059

6160
class TestExtrapolation(unittest.TestCase):
62-
def test_quadratic_convergence(self):
61+
def test_convergence_without_bcs(self):
6362

6463
# all tests are performed on x in [0, 1]
64+
linear = {"extrapolation": {"order": "linear"}}
65+
quad = {"extrapolation": {"order": "quadratic"}}
6566

6667
def x_squared(x):
6768
y = x ** 2
@@ -71,8 +72,9 @@ def x_squared(x):
7172

7273
pts = 10 ** np.arange(1, 6, 1)
7374
dx = 1 / pts
74-
l_errors_lin, r_errors_lin = get_errors(x_squared, "linear", pts)
75-
l_errors_quad, r_errors_quad = get_errors(x_squared, "quadratic", pts)
75+
76+
l_errors_lin, r_errors_lin = get_errors(x_squared, linear, pts)
77+
l_errors_quad, r_errors_quad = get_errors(x_squared, quad, pts)
7678

7779
l_lin_rates = np.log(l_errors_lin[:-1] / l_errors_lin[1:]) / np.log(
7880
dx[:-1] / dx[1:]
@@ -95,7 +97,7 @@ def x_cubed(x):
9597
r_true = 1
9698
return y, l_true, r_true
9799

98-
l_errors_lin, r_errors_lin = get_errors(x_squared, "linear", pts)
100+
l_errors_lin, r_errors_lin = get_errors(x_squared, linear, pts)
99101

100102
l_lin_rates = np.log(l_errors_lin[:-1] / l_errors_lin[1:]) / np.log(
101103
dx[:-1] / dx[1:]
@@ -111,7 +113,7 @@ def x_cubed(x):
111113
# quadratic case
112114
pts = 5 ** np.arange(1, 7, 1)
113115
dx = 1 / pts
114-
l_errors_quad, r_errors_quad = get_errors(x_cubed, "quadratic", pts)
116+
l_errors_quad, r_errors_quad = get_errors(x_cubed, quad, pts)
115117

116118
l_quad_rates = np.log(l_errors_quad[:-1] / l_errors_quad[1:]) / np.log(
117119
dx[:-1] / dx[1:]
@@ -124,6 +126,9 @@ def x_cubed(x):
124126
np.testing.assert_array_almost_equal(l_quad_rates, 3)
125127
np.testing.assert_array_almost_equal(r_quad_rates, 3, decimal=3)
126128

129+
def test_extrapolation_with_bcs(self):
130+
# simple particle with a flux bc
131+
127132

128133
if __name__ == "__main__":
129134
print("Add -v for more debug output")

0 commit comments

Comments
 (0)