Skip to content

Commit

Permalink
dsl: Add tests for OOB accesses
Browse files Browse the repository at this point in the history
  • Loading branch information
Leitevmd committed Apr 9, 2021
1 parent 8d689d5 commit 61eff9b
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,64 @@ def test_shifted_grad(self, shift, ndim):
x0 = (None if shift is None else d + shift[i] * d.spacing if
type(shift) is tuple else d + shift * d.spacing)
assert gi == getattr(f, 'd%s' % d.name)(x0=x0).evaluate

@pytest.mark.parametrize('exprs, error,', [
### Dimensions
# OOB access avoided by iteration space shrinking
(['Eq(so2.forward, so2.dx.dx.dx)'],None),
(['Eq(so2.forward, (so2+so0).dx2 )'],None),
(['Eq(so2.forward, so2.biharmonic(1/so0) )'],None),
### SubDimensions (Thickness == 0)
# No OOB indices, no error (Dimension or SubDimension)
(['Eq(so1[t+1,x,y], so1[t,x-1,y] + so1[t,x+1,y] )'],None),
(['Eq(so1[t+1,xi0,y], so1[t,xi0-1,y] + so1[t,xi0+1,y] )'],None),
# These examples would get SEGFAULT because SubDimenions are not being handled
# We are raising exceptions, but we could also increase thickness
(['Eq(sd0so2.forward, sd0so2.dx.dx.dx)'],ValueError), # This should get segfault althoug it doesn't
(['Eq(sd0so2.forward, (sd0so2+sd0so0).dx2 )'],ValueError),
(['Eq(so0[t+1,xi0,y], so0[t,xi0-1,y] + so0[t,xi0+1,y] )'],ValueError),
### SubDimensions (Thickness == 1) (Not covered by current patch)
# These examples pass just because of thickness size support their OOB indices
(['Eq(sd1so2.forward, sd1so2.dx.dx.dx)'],None),
(['Eq(sd1so2.forward, (sd1so2+sd1so0).dx2 )'],None),
(['Eq(so1[t+1,xi1,y], so1[t,xi1-2,y] + so1[t,xi1+2,y] )'],None),
# This would still segfault
# (['Eq(so1[t+1,xi1,y], so1[t,xi1-3,y] + so1[t,xi1+3,y] )'],ValueError),
])
def test_oobs(self, exprs, error):

grid = Grid(tuple([100]*2))
x , y = grid.dimensions
t = grid.stepping_dim

xi0 = SubDimension.middle(name='xi0', parent=x, thickness_left=0, thickness_right=0)
xi1 = SubDimension.middle(name='xi1', parent=x, thickness_left=1, thickness_right=1)

so0 = TimeFunction(name="so0", grid=grid, time_order=1, space_order=0)
so1 = TimeFunction(name="so1", grid=grid, time_order=1, space_order=1)
so2 = TimeFunction(name="so2", grid=grid, time_order=1, space_order=2)

sd0so0 = TimeFunction(name="sdso0", grid=grid, time_order=1, space_order=0, dimensions=(t,xi0,y))
sd0so1 = TimeFunction(name="sdso1", grid=grid, time_order=1, space_order=1, dimensions=(t,xi0,y))
sd0so2 = TimeFunction(name="sdso2", grid=grid, time_order=1, space_order=2, dimensions=(t,xi0,y))

sd1so0 = TimeFunction(name="sdso0", grid=grid, time_order=1, space_order=0, dimensions=(t,xi1,y))
sd1so1 = TimeFunction(name="sdso1", grid=grid, time_order=1, space_order=1, dimensions=(t,xi1,y))
sd1so2 = TimeFunction(name="sdso2", grid=grid, time_order=1, space_order=2, dimensions=(t,xi1,y))

# List comprehension would need explicit locals/globals mappings to eval
for i, e in enumerate(list(exprs)):
exprs[i] = eval(e)
try:
op = Operator(exprs)
except Exception as e:
if error is None or not isinstance(e, error):
assert False, "Not expected: %s %s" % (type(e),e)
else:
if error is not None:
assert False, "Should raise an %s exception" % error
print('arguments: ', op.arguments(time_M=0))
print('operator:\n',op)
op.apply(time_M=5)

0 comments on commit 61eff9b

Please sign in to comment.