Skip to content

Commit

Permalink
api: dix sympy assumptions for complex valued objects
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Mar 6, 2025
1 parent d197799 commit f250c62
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 7 deletions.
17 changes: 10 additions & 7 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,6 @@ class AbstractSymbol(sympy.Symbol, Basic, Pickable, Evaluable):
is_Symbol = True

# SymPy default assumptions
is_real = True
is_imaginary = False
is_commutative = True

__rkwargs__ = ('name', 'dtype', 'is_const')
Expand Down Expand Up @@ -411,6 +409,14 @@ def _hashable_content(self):
def dtype(self):
return self._dtype

@property
def is_real(self):
return not self.is_imaginary

@property
def is_imaginary(self):
return np.iscomplexobj(self.dtype(0))

@property
def indices(self):
return ()
Expand Down Expand Up @@ -859,7 +865,6 @@ class AbstractFunction(sympy.Function, Basic, Pickable, Evaluable):
is_AbstractFunction = True

# SymPy default assumptions
is_imaginary = False
is_commutative = True

# Devito default assumptions
Expand Down Expand Up @@ -955,6 +960,8 @@ def _sympystr(self, printer, **kwargs):
return str(self)

_latex = _sympystr
is_real = AbstractSymbol.is_real
is_imaginary = AbstractSymbol.is_imaginary

def _pretty(self, printer, **kwargs):
return printer._print_Function(self, func_name=self.name)
Expand Down Expand Up @@ -1315,10 +1322,6 @@ def is_const(self):
def is_transient(self):
return self._is_transient

@property
def is_real(self):
return not np.iscomplex(self.dtype(0))

@property
def is_persistent(self):
"""
Expand Down
16 changes: 16 additions & 0 deletions tests/test_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,22 @@ def test_inner_sparse(self):
term2 = np.inner(rec0.data.reshape(-1), rec1.data.reshape(-1))
assert np.isclose(term1/term2 - 1, 0.0, rtol=0.0, atol=1e-5)

@pytest.mark.parametrize('dtype', [np.float32, np.complex64])
def test_norm_dense(self, dtype):
"""
Test that norm produces the correct result against NumPy
"""
grid = Grid((101, 101), extent=(1000., 1000.))

f = Function(name='f', grid=grid, dtype=dtype)

f.data[:] = 1 + np.random.randn(*f.shape).astype(grid.dtype)
if np.iscomplexobj(f.data):
f.data[:] += 1j*np.random.randn(*f.shape).astype(grid.dtype)
term1 = np.linalg.norm(f.data)
term2 = norm(f)
assert np.isclose(term1/term2 - 1, 0.0, rtol=0.0, atol=1e-5)

def test_norm_sparse(self):
"""
Test that norm produces the correct result against NumPy
Expand Down
13 changes: 13 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,19 @@ def test_modified_sympy_assumptions():
assert s2 == s1


def test_real():
for dtype in [np.float32, np.complex64]:
c = Constant(name='c', dtype=dtype)
assert c.is_real is not np.iscomplexobj(dtype(0))
assert c.is_imaginary is np.iscomplexobj(dtype(0))
f = Function(name='f', dtype=dtype, grid=Grid((11,)))
assert f.is_real is not np.iscomplexobj(dtype(0))
assert f.is_imaginary is np.iscomplexobj(dtype(0))
s = dSymbol(name='s', dtype=dtype)
assert s.is_real is not np.iscomplexobj(dtype(0))
assert s.is_imaginary is np.iscomplexobj(dtype(0))


def test_constant():
c = Constant(name='c')

Expand Down

0 comments on commit f250c62

Please sign in to comment.