Skip to content

Commit 094ee49

Browse files
committed
#704 tested that the extrapolation is behaving properly
1 parent b6f3546 commit 094ee49

File tree

2 files changed

+134
-0
lines changed

2 files changed

+134
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#
2+
# Test for the operator class
3+
#
4+
import pybamm
5+
from tests import (
6+
get_mesh_for_testing,
7+
get_p2d_mesh_for_testing,
8+
get_1p1d_mesh_for_testing,
9+
)
10+
11+
import numpy as np
12+
from scipy.sparse import kron, eye
13+
import unittest
14+
15+
16+
def errors(pts, function, extrap):
17+
18+
domain = "test"
19+
x = pybamm.SpatialVariable("x", domain=domain)
20+
geometry = {
21+
domain: {"primary": {x: {"min": pybamm.Scalar(0), "max": pybamm.Scalar(1)}}}
22+
}
23+
submesh_types = {domain: pybamm.MeshGenerator(pybamm.Uniform1DSubMesh)}
24+
var_pts = {x: pts}
25+
mesh = pybamm.Mesh(geometry, submesh_types, var_pts)
26+
27+
spatial_methods = {"test": pybamm.FiniteVolume}
28+
disc = pybamm.Discretisation(mesh, spatial_methods)
29+
30+
disc.spatial_methods["test"].extrapolation = extrap
31+
var = pybamm.Variable("var", domain="test")
32+
left_extrap = pybamm.BoundaryValue(var, "left")
33+
right_extrap = pybamm.BoundaryValue(var, "right")
34+
35+
submesh = mesh["test"]
36+
y, l_true, r_true = function(submesh[0].nodes)
37+
38+
disc.set_variable_slices([var])
39+
left_extrap_processed = disc.process_symbol(left_extrap)
40+
right_extrap_processed = disc.process_symbol(right_extrap)
41+
42+
l_error = np.abs(l_true - left_extrap_processed.evaluate(None, y))
43+
r_error = np.abs(r_true - right_extrap_processed.evaluate(None, y))
44+
45+
return l_error, r_error
46+
47+
48+
def get_errors(function, extrap, pts):
49+
50+
l_errors = np.zeros(pts.shape)
51+
r_errors = np.zeros(pts.shape)
52+
53+
for i, pt in enumerate(pts):
54+
l_errors[i], r_errors[i] = errors(pt, function, extrap)
55+
56+
return l_errors, r_errors
57+
58+
59+
class TestExtrapolation(unittest.TestCase):
60+
def test_quadratic_convergence(self):
61+
62+
# all tests are performed on x in [0, 1]
63+
64+
def x_squared(x):
65+
y = x ** 2
66+
l_true = 0
67+
r_true = 1
68+
return y, l_true, r_true
69+
70+
pts = 10 ** np.arange(1, 6, 1)
71+
dx = 1 / pts
72+
l_errors_lin, r_errors_lin = get_errors(x_squared, "linear", pts)
73+
l_errors_quad, r_errors_quad = get_errors(x_squared, "quadratic", pts)
74+
75+
l_lin_rates = np.log(l_errors_lin[:-1] / l_errors_lin[1:]) / np.log(
76+
dx[:-1] / dx[1:]
77+
)
78+
79+
r_lin_rates = np.log(r_errors_lin[:-1] / r_errors_lin[1:]) / np.log(
80+
dx[:-1] / dx[1:]
81+
)
82+
83+
np.testing.assert_array_almost_equal(l_lin_rates, 2)
84+
np.testing.assert_array_almost_equal(r_lin_rates, 2)
85+
86+
# check quadratic is equal up to machine precision
87+
np.testing.assert_array_almost_equal(l_errors_quad, 0, decimal=14)
88+
np.testing.assert_array_almost_equal(r_errors_quad, 0, decimal=14)
89+
90+
def x_cubed(x):
91+
y = x ** 3
92+
l_true = 0
93+
r_true = 1
94+
return y, l_true, r_true
95+
96+
l_errors_lin, r_errors_lin = get_errors(x_squared, "linear", pts)
97+
98+
l_lin_rates = np.log(l_errors_lin[:-1] / l_errors_lin[1:]) / np.log(
99+
dx[:-1] / dx[1:]
100+
)
101+
102+
r_lin_rates = np.log(r_errors_lin[:-1] / r_errors_lin[1:]) / np.log(
103+
dx[:-1] / dx[1:]
104+
)
105+
106+
np.testing.assert_array_almost_equal(l_lin_rates, 2)
107+
np.testing.assert_array_almost_equal(r_lin_rates, 2)
108+
109+
# quadratic case
110+
pts = 5 ** np.arange(1, 7, 1)
111+
dx = 1 / pts
112+
l_errors_quad, r_errors_quad = get_errors(x_cubed, "quadratic", pts)
113+
114+
l_quad_rates = np.log(l_errors_quad[:-1] / l_errors_quad[1:]) / np.log(
115+
dx[:-1] / dx[1:]
116+
)
117+
118+
r_quad_rates = np.log(r_errors_quad[:-1] / r_errors_quad[1:]) / np.log(
119+
dx[:-1] / dx[1:]
120+
)
121+
122+
np.testing.assert_array_almost_equal(l_quad_rates, 3)
123+
np.testing.assert_array_almost_equal(r_quad_rates, 3, decimal=3)
124+
125+
126+
if __name__ == "__main__":
127+
print("Add -v for more debug output")
128+
import sys
129+
130+
if "-v" in sys.argv:
131+
debug = True
132+
pybamm.settings.debug_mode = True
133+
unittest.main()
134+

0 commit comments

Comments
 (0)