Skip to content

Commit

Permalink
api: fix norm with complex numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Mar 6, 2025
1 parent 464fcf3 commit d197799
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 17 deletions.
4 changes: 2 additions & 2 deletions devito/builtins/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def norm(f, order=2):
s = dv.types.Symbol(name='sum', dtype=n.dtype)

op = dv.Operator([dv.Eq(s, 0.0)] + eqns +
[dv.Inc(s, dv.Abs(Pow(p, order))), dv.Eq(n[0], s)],
[dv.Inc(s, Pow(dv.Abs(p), order)), dv.Eq(n[0], s)],
name='norm%d' % order)
op.apply(**kwargs)

v = np.power(n.data[0], 1/order)

return f.dtype(v)
return np.real(f.dtype(v))


@dv.switchconfig(log_level='ERROR')
Expand Down
5 changes: 4 additions & 1 deletion devito/builtins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
# NOTE: np.float128 isn't really a thing, see for example
# https://github.com/numpy/numpy/issues/10288
# https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1070
np.float64: np.float64
np.float64: np.float64,
# ComplexX accumulates on Complex2X
np.complex64: np.complex128,
np.complex128: np.complex128,
}


Expand Down
5 changes: 4 additions & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,6 @@ class AbstractFunction(sympy.Function, Basic, Pickable, Evaluable):
is_AbstractFunction = True

# SymPy default assumptions
is_real = True
is_imaginary = False
is_commutative = True

Expand Down Expand Up @@ -1316,6 +1315,10 @@ def is_const(self):
def is_transient(self):
return self._is_transient

@property
def is_real(self):
return not np.iscomplex(self.dtype(0))

@property
def is_persistent(self):
"""
Expand Down
32 changes: 19 additions & 13 deletions examples/seismic/tutorials/17_fourier_mode.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
"outputs": [],
"source": [
"from devito import *\n",
"\n",
"from examples.seismic import demo_model, AcquisitionGeometry, plot_velocity\n",
"from examples.seismic.acoustic import AcousticWaveSolver\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt"
]
},
Expand All @@ -42,6 +43,7 @@
}
],
"source": [
"#NBVAL_IGNORE_OUTPUT\n",
"model = demo_model('layers-isotropic', vp=3.0, origin=(0., 0.), shape=(101, 101), spacing=(10., 10.), nbl=40, nlayers=4)"
]
},
Expand All @@ -62,6 +64,7 @@
}
],
"source": [
"#NBVAL_IGNORE_OUTPUT\n",
"plot_velocity(model)"
]
},
Expand Down Expand Up @@ -114,13 +117,6 @@
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"damp(x, y)*Derivative(u(t, x, y), t) - Derivative(u(t, x, y), (x, 2)) - Derivative(u(t, x, y), (y, 2)) + Derivative(u(t, x, y), (t, 2))/vp(x, y)**2\n"
]
},
{
"data": {
"text/latex": [
Expand All @@ -147,8 +143,6 @@
"# We can now write the PDE\n",
"pde = model.m * u.dt2 - u.laplace + model.damp * u.dt\n",
"\n",
"# The PDE representation is as on paper\n",
"print(pde)\n",
"\n",
"# Stencil update\n",
"stencil = Eq(u.forward, solve(pde, u.forward))\n",
Expand Down Expand Up @@ -209,11 +203,11 @@
"data": {
"text/plain": [
"PerformanceSummary([(PerfKey(name='section0', rank=None),\n",
" PerfEntry(time=0.016828, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.01551399999999999, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section1', rank=None),\n",
" PerfEntry(time=0.002213999999999991, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" PerfEntry(time=0.0023819999999999913, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[])),\n",
" (PerfKey(name='section2', rank=None),\n",
" PerfEntry(time=0.0021949999999999943, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
" PerfEntry(time=0.002333999999999994, gflopss=0.0, gpointss=0.0, oi=0.0, ops=0, itershapes=[]))])"
]
},
"execution_count": 9,
Expand All @@ -222,6 +216,7 @@
}
],
"source": [
"#NBVAL_IGNORE_OUTPUT\n",
"op(dt=model.critical_dt)"
]
},
Expand All @@ -242,6 +237,7 @@
}
],
"source": [
"#NBVAL_IGNORE_OUTPUT\n",
"plt.figure(figsize=(12, 6))\n",
"plt.subplot(1, 2, 1)\n",
"plt.imshow(np.real(freq_mode.data.T), cmap='seismic', vmin=-1e2, vmax=1e2)\n",
Expand All @@ -251,6 +247,16 @@
"plt.colorbar()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"assert np.isclose(norm(freq_mode), 13873.049, atol=0, rtol=1e-4)\n",
"assert np.isclose(norm(u), 323.74207, atol=0, rtol=1e-4)"
]
}
],
"metadata": {
Expand Down

0 comments on commit d197799

Please sign in to comment.