diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index e679d7a785..fe0a15c3e0 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -808,6 +808,11 @@ class IntelDevice(Device): max_mem_trans_nbytes = 64 + def __init__(self, *args, sub_group_size=32, **kwargs): + super().__init__(*args, **kwargs) + + self.sub_group_size = sub_group_size + @property def march(self): return '' @@ -894,10 +899,10 @@ def march(cls): AMDGPUX = AmdDevice('amdgpuX') INTELGPUX = IntelDevice('intelgpuX') -PVC = IntelDevice('pvc', max_threads_per_block=4096) # Legacy codename for MAX GPUs -INTELGPUMAX = IntelDevice('intelgpuMAX', max_threads_per_block=4096) -MAX1100 = IntelDevice('max1100', max_threads_per_block=4096) -MAX1550 = IntelDevice('max1550', max_threads_per_block=4096) +PVC = IntelDevice('pvc') # Legacy codename for MAX GPUs +INTELGPUMAX = IntelDevice('intelgpuMAX') +MAX1100 = IntelDevice('max1100') +MAX1550 = IntelDevice('max1550') platform_registry = Platform.registry platform_registry['cpu64'] = get_platform # Autodetection diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 136e7bbfee..72307c5ac5 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -21,13 +21,14 @@ Symbol) from devito.types.object import AbstractObject, LocalObject -__all__ = ['Node', 'Block', 'Expression', 'Callable', 'Call', 'ExprStmt', - 'Conditional', 'Iteration', 'List', 'Section', 'TimedList', 'Prodder', - 'MetaCall', 'PointerCast', 'HaloSpot', 'Definition', 'ExpressionBundle', - 'AugmentedExpression', 'Increment', 'Return', 'While', 'ListMajor', - 'ParallelIteration', 'ParallelBlock', 'Dereference', 'Lambda', - 'SyncSpot', 'Pragma', 'DummyExpr', 'BlankLine', 'ParallelTree', - 'BusyWait', 'CallableBody', 'Transfer'] +__all__ = ['Node', 'MultiTraversable', 'Block', 'Expression', 'Callable', + 'Call', 'ExprStmt', 'Conditional', 'Iteration', 'List', 'Section', + 'TimedList', 'Prodder', 'MetaCall', 'PointerCast', 'HaloSpot', + 'Definition', 'ExpressionBundle', 'AugmentedExpression', + 'Increment', 'Return', 'While', 'ListMajor', 'ParallelIteration', + 'ParallelBlock', 'Dereference', 'Lambda', 'SyncSpot', 'Pragma', + 'DummyExpr', 'BlankLine', 'ParallelTree', 'BusyWait', 'UsingNamespace', + 'CallableBody', 'Transfer'] # First-class IET nodes @@ -175,6 +176,15 @@ class ExprStmt(object): pass +class MultiTraversable(Node): + + """ + An abstract base class for Nodes comprising more than one traversable children. + """ + + pass + + class List(Node): """A sequence of Nodes.""" @@ -740,7 +750,7 @@ def defines(self): return self.all_parameters -class CallableBody(Node): +class CallableBody(MultiTraversable): """ The immediate child of a Callable. @@ -1057,7 +1067,7 @@ class Lambda(Node): A callable C++ lambda function. Several syntaxes are possible; here we implement one of the common ones: - [captures](parameters){body} + [captures](parameters){body} SPECIAL [[attributes]] For more info about C++ lambda functions: @@ -1071,14 +1081,21 @@ class Lambda(Node): The captures of the lambda function. parameters : list of Basic or expr-like, optional The objects in input to the lambda function. + special : list of Basic, optional + Placeholder for custom lambdas, to add in e.g. macros. + attributes : list of str, optional + The attributes of the lambda function. """ _traversable = ['body'] - def __init__(self, body, captures=None, parameters=None): + def __init__(self, body, captures=None, parameters=None, special=None, + attributes=None): self.body = as_tuple(body) self.captures = as_tuple(captures) self.parameters = as_tuple(parameters) + self.special = as_tuple(special) + self.attributes = as_tuple(attributes) def __repr__(self): return "Lambda[%s](%s)" % (self.captures, self.parameters) @@ -1178,6 +1195,19 @@ def periodic(self): return self._periodic +class UsingNamespace(Node): + + """ + A C++ using namespace directive. + """ + + def __init__(self, namespace): + self.namespace = namespace + + def __repr__(self): + return "" % self.namespace + + class Pragma(Node): """ diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index f6e266ed9d..d94f2f1d78 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -573,7 +573,7 @@ def visit_Callable(self, o): signature = self._gen_signature(o) return c.FunctionBody(signature, c.Block(body)) - def visit_CallableBody(self, o): + def visit_MultiTraversable(self, o): body = [] prev = None for i in o.children: @@ -585,6 +585,9 @@ def visit_CallableBody(self, o): body.extend(as_tuple(v)) return c.Collection(body) + def visit_UsingNamespace(self, o): + return c.Statement('using namespace %s' % ccode(o.namespace)) + def visit_Lambda(self, o): body = [] for i in o.children: @@ -595,7 +598,15 @@ def visit_Lambda(self, o): body.extend(as_tuple(v)) captures = [str(i) for i in o.captures] decls = [i.inline() for i in self._args_decl(o.parameters)] - top = c.Line('[%s](%s)' % (', '.join(captures), ', '.join(decls))) + extra = [] + if o.special: + extra.append(' ') + extra.append(' '.join(str(i) for i in o.special)) + if o.attributes: + extra.append(' ') + extra.append(' '.join('[[%s]]' % i for i in o.attributes)) + top = c.Line('[%s](%s)%s' % + (', '.join(captures), ', '.join(decls), ''.join(extra))) return LambdaCollection([top, c.Block(body)]) def visit_HaloSpot(self, o): @@ -677,7 +688,7 @@ def visit_Operator(self, o, mode='all'): includes = self._operator_includes(o) + [blankline] # Namespaces - namespaces = [c.Statement('using namespace %s' % i) for i in o._namespaces] + namespaces = [self._visit(i) for i in o._namespaces] if namespaces: namespaces.append(blankline) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index b8de8bade8..8ff49b5e5e 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -16,10 +16,10 @@ __all__ = ['CondEq', 'CondNe', 'IntDiv', 'CallFromPointer', # noqa 'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', - 'InlineIf', 'ReservedWord', 'Keyword', 'String', 'Macro', 'Class', - 'MacroArgument', 'CustomType', 'Deref', 'Namespace', 'Rvalue', - 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null', 'SizeOf', 'rfunc', - 'cast_mapper', 'BasicWrapperMixin'] + 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', + 'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', + 'Namespace', 'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null', + 'SizeOf', 'rfunc', 'cast_mapper', 'BasicWrapperMixin'] class CondEq(sympy.Eq): @@ -541,8 +541,7 @@ class DefFunction(Function, Pickable): https://github.com/sympy/sympy/issues/4297 """ - __rargs__ = ('name', 'arguments') - __rkwargs__ = ('template',) + __rargs__ = ('name', 'arguments', 'template') def __new__(cls, name, arguments=None, template=None, **kwargs): if isinstance(name, str): @@ -609,6 +608,12 @@ def _sympystr(self, printer): __reduce_ex__ = Pickable.__reduce_ex__ +class MathFunction(DefFunction): + + # Supposed to involve real operands + is_commutative = True + + class InlineIf(sympy.Expr, Pickable): """ diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index b4ea829b70..220e4d27b8 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -2,7 +2,7 @@ from collections.abc import Iterable from functools import singledispatch -from sympy import Pow, Add, Mul, Min, Max, SympifyError, Tuple, sympify +from sympy import Pow, Add, Mul, Min, Max, S, SympifyError, Tuple, sympify from sympy.core.add import _addsort from sympy.core.mul import _mulsort @@ -146,6 +146,11 @@ def _(expr, args, kwargs): @_uxreplace_handle.register(Mul) def _(expr, args, kwargs): + # Perform some basic simplifications at least + args = [i for i in args if i != 1] + if any(i == 0 for i in args): + return S.Zero + if all(i.is_commutative for i in args): _mulsort(args) _eval_numbers(expr, args) diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 6924ffffbd..fbaf2757d2 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -241,6 +241,8 @@ def _print_DefFunction(self, expr): template = '' return "%s%s(%s)" % (expr.name, template, ','.join(arguments)) + _print_MathFunction = _print_DefFunction + def _print_Fallback(self, expr): return expr.__str__() diff --git a/tests/test_iet.py b/tests/test_iet.py index 6ee6c13ca7..41bae75821 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -144,6 +144,26 @@ def test_list_denesting(): assert str(l3) == str(l2) +def test_lambda(): + grid = Grid(shape=(4, 4, 4)) + x, y, z = grid.dimensions + + u = Function(name='u', grid=grid) + + e0 = DummyExpr(u.indexed, 1) + e1 = DummyExpr(u.indexed, 2) + + body = List(body=[e0, e1]) + lam = Lambda(body, ['='], [u.indexed], attributes=['my_attr']) + + assert str(lam) == """\ +[=](float *restrict u) [[my_attr]] +{ + u = 1; + u = 2; +}""" + + def test_make_cpp_parfor(): """ Test construction of a C++ parallel for. This excites the IET construction