Skip to content

Commit

Permalink
Merge pull request #2245 from devitocodes/funcs_on_subdims
Browse files Browse the repository at this point in the history
dsl: Introduce ability to define Functions on Subdomains
  • Loading branch information
FabioLuporini authored Mar 4, 2025
2 parents 34dba05 + 3b8ec13 commit f854b0c
Show file tree
Hide file tree
Showing 26 changed files with 5,416 additions and 572 deletions.
12 changes: 7 additions & 5 deletions devito/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,12 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
stop += sendcounts[i]
data_slice = recvbuf[slice(start, stop, step)]
shape = [r.stop-r.start for r in self._distributor.all_ranges[i]]
idx = [slice(r.start, r.stop, r.step)
for r in self._distributor.all_ranges[i]]
for i in range(len(self.shape) - len(self._distributor.glb_shape)):
shape.insert(i, glb_shape[i])
idx.insert(i, slice(0, glb_shape[i]+1, 1))
idx = [slice(r.start - d.glb_min, r.stop - d.glb_min, r.step)
for r, d in zip(self._distributor.all_ranges[i],
self._distributor.decomposition)]
for j in range(len(self.shape) - len(self._distributor.glb_shape)):
shape.insert(j, glb_shape[j])
idx.insert(j, slice(0, glb_shape[j]+1, 1))
retval[tuple(idx)] = data_slice.reshape(tuple(shape))
return retval
else:
Expand Down Expand Up @@ -329,6 +330,7 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
@_check_idx
def __setitem__(self, glb_idx, val, comm_type):
loc_idx = self._index_glb_to_loc(glb_idx)

if loc_idx is NONLOCAL:
# no-op
return
Expand Down
79 changes: 49 additions & 30 deletions devito/data/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,26 @@ def index_glb_to_loc(self, *args, rel=True):
>>> d.index_glb_to_loc((1, 6), rel=False)
(5, 6)
"""
# Offset the loc_abs_min, loc_abs_max, glb_min, and glb_max
# In the case of a Function defined on a SubDomain, the global indices
# for accessing the data associated with this Function will run from
# `ltkn` to `x_M-rtkn-1`. However, the indices for accessing this array
# will run from `0` to `x_M-ltkn-rtkn-1`. As such, the global minimum
# (`ltkn`) should be subtracted for the purpose of indexing into the local
# array.
if not self.loc_empty:
loc_abs_min = self.loc_abs_min - self.glb_min
loc_abs_max = self.loc_abs_max - self.glb_min
glb_max = self.glb_max - self.glb_min
else:
loc_abs_min = self.loc_abs_min
loc_abs_max = self.loc_abs_max
glb_max = self.glb_max

glb_min = 0

base = self.loc_abs_min if rel is True else 0
top = self.loc_abs_max
base = loc_abs_min if rel else 0
top = loc_abs_max

if len(args) == 1:
glb_idx = args[0]
Expand All @@ -217,11 +234,11 @@ def index_glb_to_loc(self, *args, rel=True):
return None
# -> Handle negative index
if glb_idx < 0:
glb_idx = self.glb_max + glb_idx + 1
glb_idx = glb_max + glb_idx + 1
# -> Do the actual conversion
if self.loc_abs_min <= glb_idx <= self.loc_abs_max:
if loc_abs_min <= glb_idx <= loc_abs_max:
return glb_idx - base
elif self.glb_min <= glb_idx <= self.glb_max:
elif glb_min <= glb_idx <= glb_max:
return None
else:
# This should raise an exception when used to access a numpy.array
Expand All @@ -239,30 +256,32 @@ def index_glb_to_loc(self, *args, rel=True):
elif isinstance(glb_idx, slice):
if self.loc_empty:
return slice(-1, -3)
if glb_idx.step >= 0 and glb_idx.stop == self.glb_min:
glb_idx_min = self.glb_min if glb_idx.start is None \
if glb_idx.step >= 0 and glb_idx.stop == glb_min:
glb_idx_min = glb_min if glb_idx.start is None \
else glb_idx.start
glb_idx_max = self.glb_min
glb_idx_max = glb_min
retfunc = lambda a, b: slice(a, b, glb_idx.step)
elif glb_idx.step >= 0:
glb_idx_min = self.glb_min if glb_idx.start is None \
glb_idx_min = glb_min if glb_idx.start is None \
else glb_idx.start
glb_idx_max = self.glb_max if glb_idx.stop is None \
glb_idx_max = glb_max \
if glb_idx.stop is None \
else glb_idx.stop-1
retfunc = lambda a, b: slice(a, b + 1, glb_idx.step)
else:
glb_idx_min = self.glb_min if glb_idx.stop is None \
glb_idx_min = glb_min if glb_idx.stop is None \
else glb_idx.stop+1
glb_idx_max = self.glb_max if glb_idx.start is None \
glb_idx_max = glb_max if glb_idx.start is None \
else glb_idx.start
retfunc = lambda a, b: slice(b, a - 1, glb_idx.step)
else:
raise TypeError("Cannot convert index from `%s`" % type(glb_idx))
# -> Handle negative min/max
if glb_idx_min is not None and glb_idx_min < 0:
glb_idx_min = self.glb_max + glb_idx_min + 1
glb_idx_min = glb_max + glb_idx_min + 1
if glb_idx_max is not None and glb_idx_max < 0:
glb_idx_max = self.glb_max + glb_idx_max + 1
glb_idx_max = glb_max + glb_idx_max + 1

# -> Do the actual conversion
# Compute loc_min. For a slice with step > 0 this will be
# used to produce slice.start and for a slice with step < 0 slice.stop.
Expand All @@ -271,19 +290,19 @@ def index_glb_to_loc(self, *args, rel=True):
# coincide with loc_abs_min.
if isinstance(glb_idx, slice) and glb_idx.step is not None \
and glb_idx.step > 1:
if glb_idx_min > self.loc_abs_max:
if glb_idx_min > loc_abs_max:
return retfunc(-1, -3)
elif glb_idx.start is None: # glb start is zero.
loc_min = self.loc_abs_min - base \
loc_min = loc_abs_min - base \
+ np.mod(glb_idx.step - np.mod(base, glb_idx.step),
glb_idx.step)
else: # glb start is given explicitly
loc_min = self.loc_abs_min - base \
loc_min = loc_abs_min - base \
+ np.mod(glb_idx.step - np.mod(base - glb_idx.start,
glb_idx.step), glb_idx.step)
elif glb_idx_min is None or glb_idx_min < self.loc_abs_min:
loc_min = self.loc_abs_min - base
elif glb_idx_min > self.loc_abs_max:
elif glb_idx_min is None or glb_idx_min < loc_abs_min:
loc_min = loc_abs_min - base
elif glb_idx_min > loc_abs_max:
return retfunc(-1, -3)
else:
loc_min = glb_idx_min - base
Expand All @@ -294,19 +313,19 @@ def index_glb_to_loc(self, *args, rel=True):
# coincide with loc_abs_max.
if isinstance(glb_idx, slice) and glb_idx.step is not None \
and glb_idx.step < -1:
if glb_idx_max < self.loc_abs_min:
if glb_idx_max < loc_abs_min:
return retfunc(-1, -3)
elif glb_idx.start is None:
loc_max = top - base \
+ np.mod(glb_idx.step - np.mod(top - self.glb_max,
+ np.mod(glb_idx.step - np.mod(top - glb_max,
glb_idx.step), glb_idx.step)
else:
loc_max = top - base \
+ np.mod(glb_idx.step - np.mod(top - glb_idx.start,
glb_idx.step), glb_idx.step)
elif glb_idx_max is None or glb_idx_max > self.loc_abs_max:
loc_max = self.loc_abs_max - base
elif glb_idx_max < self.loc_abs_min:
elif glb_idx_max is None or glb_idx_max > loc_abs_max:
loc_max = loc_abs_max - base
elif glb_idx_max < loc_abs_min:
return retfunc(-1, -3)
else:
loc_max = glb_idx_max - base
Expand All @@ -321,19 +340,19 @@ def index_glb_to_loc(self, *args, rel=True):
return None
abs_ofs, side = args
if side == LEFT:
rel_ofs = self.glb_min + abs_ofs - base
rel_ofs = glb_min + abs_ofs - base
if abs_ofs >= base and abs_ofs <= top:
return rel_ofs
elif abs_ofs > top:
return top + 1
else:
return None
else:
rel_ofs = abs_ofs - (self.glb_max - top)
if abs_ofs >= self.glb_max - top and abs_ofs <= self.glb_max - base:
rel_ofs = abs_ofs - (glb_max - top)
if abs_ofs >= glb_max - top and abs_ofs <= glb_max - base:
return rel_ofs
elif abs_ofs > self.glb_max - base:
return self.glb_max - base + 1
elif abs_ofs > glb_max - base:
return glb_max - base + 1
else:
return None
else:
Expand Down
13 changes: 11 additions & 2 deletions devito/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,23 @@ class DevitoDeprecation():
@cached_property
def coeff_warn(self):
warn("The Coefficient API is deprecated and will be removed, coefficients should"
"be passed directly to the derivative object `u.dx(weights=...)",
" be passed directly to the derivative object `u.dx(weights=...)",
DeprecationWarning, stacklevel=2)
return

@cached_property
def symbolic_warn(self):
warn("coefficients='symbolic' is deprecated, coefficients should"
"be passed directly to the derivative object `u.dx(weights=...)",
" be passed directly to the derivative object `u.dx(weights=...)",
DeprecationWarning, stacklevel=2)
return

@cached_property
def subdomain_warn(self):
warn("Passing `SubDomain`s to `Grid` on instantiation using `mygrid ="
" Grid(..., subdomains=(mydomain, ...))` is deprecated. The `Grid`"
" should instead be passed as a kwarg when instantiating a subdomain"
" `mydomain = MyDomain(grid=mygrid)`",
DeprecationWarning, stacklevel=2)
return

Expand Down
14 changes: 10 additions & 4 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def time_order(self):
@cached_property
def grid(self):
grids = {getattr(i, 'grid', None) for i in self._args_diff} - {None}
grids = {g.root for g in grids}
if len(grids) > 1:
warning("Expression contains multiple grids, returning first found")
try:
Expand All @@ -86,6 +87,11 @@ def dimensions(self):
return tuple(filter_ordered(flatten(getattr(i, 'dimensions', ())
for i in self._args_diff)))

@cached_property
def root_dimensions(self):
"""Tuple of root Dimensions of the physical space Dimensions."""
return tuple(d.root for d in self.dimensions if d.is_Space)

@property
def indices_ref(self):
"""The reference indices of the object (indices at first creation)."""
Expand Down Expand Up @@ -317,7 +323,7 @@ def laplacian(self, shift=None, order=None, method='FD', **kwargs):
"""
w = kwargs.get('weights', kwargs.get('w'))
order = order or self.space_order
space_dims = [d for d in self.dimensions if d.is_Space]
space_dims = self.root_dimensions
shift_x0 = make_shift_x0(shift, (len(space_dims),))
derivs = tuple('d%s2' % d.name for d in space_dims)
return Add(*[getattr(self, d)(x0=shift_x0(shift, space_dims[i], None, i),
Expand All @@ -344,7 +350,7 @@ def div(self, shift=None, order=None, method='FD', **kwargs):
Custom weights for the finite difference coefficients.
"""
w = kwargs.get('weights', kwargs.get('w'))
space_dims = [d for d in self.dimensions if d.is_Space]
space_dims = self.root_dimensions
shift_x0 = make_shift_x0(shift, (len(space_dims),))
order = order or self.space_order
return Add(*[getattr(self, 'd%s' % d.name)(x0=shift_x0(shift, d, None, i),
Expand All @@ -371,7 +377,7 @@ def grad(self, shift=None, order=None, method='FD', **kwargs):
Custom weights for the finite
"""
from devito.types.tensor import VectorFunction, VectorTimeFunction
space_dims = [d for d in self.dimensions if d.is_Space]
space_dims = self.root_dimensions
shift_x0 = make_shift_x0(shift, (len(space_dims),))
order = order or self.space_order
w = kwargs.get('weights', kwargs.get('w'))
Expand All @@ -387,7 +393,7 @@ def biharmonic(self, weight=1):
Generates a symbolic expression for the weighted biharmonic operator w.r.t.
all spatial Dimensions Laplace(weight * Laplace (self))
"""
space_dims = [d for d in self.dimensions if d.is_Space]
space_dims = self.root_dimensions
derivs = tuple('d%s2' % d.name for d in space_dims)
return Add(*[getattr(self.laplace * weight, d) for d in derivs])

Expand Down
1 change: 1 addition & 0 deletions devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def functions(self):
def grid(self):
grids = set(f.grid for f in self.functions if f.is_AbstractFunction)
grids.discard(None)
grids = {g.root for g in grids}
if len(grids) == 0:
return None
elif len(grids) == 1:
Expand Down
42 changes: 35 additions & 7 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from devito.symbolics import (retrieve_indexed, uxreplace, retrieve_dimensions,
retrieve_functions)
from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered
from devito.tools import (Ordering, as_tuple, flatten, filter_sorted, filter_ordered,
frozendict)
from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension,
ConditionalDimension)
from devito.types.array import Array
Expand Down Expand Up @@ -112,11 +113,7 @@ def lower_exprs(expressions, subs=None, **kwargs):
def _lower_exprs(expressions, subs):
processed = []
for expr in as_tuple(expressions):
try:
dimension_map = expr.subdomain.dimension_map
except AttributeError:
# Some Relationals may be pure SymPy objects, thus lacking the subdomain
dimension_map = {}
dimension_map = _make_dimension_map(expr)

# Handle Functions (typical case)
mapper = {f: _lower_exprs(f.indexify(subs=dimension_map), subs)
Expand Down Expand Up @@ -160,6 +157,30 @@ def _lower_exprs(expressions, subs):
return processed.pop()


def _make_dimension_map(expr):
"""
Make the dimension_map for an expression. In the basic case, this is extracted
directly from the SubDomain attached to the expression.
The indices of a Function defined on a SubDomain will all be the SubDimensions of
that SubDomain. In this case, the dimension_map should be extended with
`{ix_f: ix_i, iy_f: iy_i}` where `ix_f` is the SubDimension on which the Function is
defined, and `ix_i` is the SubDimension to be iterated over.
"""
try:
dimension_map = {**expr.subdomain.dimension_map}
except AttributeError:
# Some Relationals may be pure SymPy objects, thus lacking the SubDomain
dimension_map = {}
else:
functions = [f for f in retrieve_functions(expr) if f._is_on_subdomain]
for f in functions:
dimension_map.update({d: expr.subdomain.dimension_map[d.root]
for d in f.space_dimensions if d.is_Sub})

return frozendict(dimension_map)


def concretize_subdims(exprs, **kwargs):
"""
Given a list of expressions, return a new list where all user-defined
Expand Down Expand Up @@ -206,7 +227,14 @@ def _(v, mapper, rebuilt, sregistry):

@_concretize_subdims.register(Eq)
def _(expr, mapper, rebuilt, sregistry):
for d in expr.free_symbols:
# Split and reorder symbols so SubDimensions are processed before lone Thicknesses
# This means that if a Thickness appears both in the expression and attached to
# a SubDimension, it gets concretised with the SubDimension.
thicknesses = {i for i in expr.free_symbols if isinstance(i, Thickness)}
symbols = expr.free_symbols.difference(thicknesses)

# Iterate over all other symbols before iterating over standalone thicknesses
for d in tuple(symbols) + tuple(thicknesses):
_concretize_subdims(d, mapper, rebuilt, sregistry)

# Subdimensions can be hiding in implicit dims
Expand Down
Loading

0 comments on commit f854b0c

Please sign in to comment.