Skip to content

Commit 077d7a7

Browse files
committed
compiler: Fix MPI alignment by using alignment_elems
1 parent 86d0e63 commit 077d7a7

4 files changed

Lines changed: 61 additions & 34 deletions

File tree

devito/mpi/halo_scheme.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ def __init__(self, exprs, ispace):
150150

151151
# Further constraints on the `omapper` derivation. At construction time
152152
# there's none, but lowering passes may change this
153+
# * `_alignment` may be a positive integer representing the alignment
154+
# requirement, in number of *elements*, of the underlying expressions
153155
self._alignment = None
154156

155157
def __repr__(self):

devito/passes/iet/mpi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def _mark_overlappable(iet):
319319
for hs in found:
320320
exprs = [n.expr for n in FindNodes(Expression).visit(hs)]
321321
objs = search(exprs, (VectorAccess, TensorMove))
322-
alignment = max([i._expected_alignment for i in objs], default=None)
322+
alignment = max([i._expected_alignment_elems for i in objs], default=None)
323323

324324
hsf = hs.halo_scheme._rebuild(alignment=alignment)
325325
hs1 = hs._rebuild(halo_scheme=hsf)

devito/symbolics/extended_sympy.py

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
import re
55
from contextlib import suppress
6+
from functools import cached_property
67

78
import numpy as np
89
import sympy
@@ -25,7 +26,7 @@
2526
'MathFunction', 'InlineIf', 'Reserved', 'ReservedWord', 'Keyword',
2627
'String', 'Macro', 'Class', 'MacroArgument', 'RoundUp', 'Deref',
2728
'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin',
28-
'ValueLimit', 'VectorAccess']
29+
'ValueLimit', 'AlignedAccess', 'VectorAccess']
2930

3031

3132
class CondEq(sympy.Eq):
@@ -888,16 +889,47 @@ def __str__(self):
888889
__repr__ = __str__
889890

890891

891-
class VectorAccess(Expr, Pickable, BasicWrapperMixin):
892+
class AlignedAccess(Expr, Reserved, BasicWrapperMixin):
892893

893894
"""
894-
Represent a vector access operation at high-level.
895+
Abstract base class for an aligned access operation, that is an access to a
896+
memory location that is guaranteed to be aligned to a certain byte
897+
boundary.
895898
"""
896899

897-
_expected_alignment = 16
900+
@property
901+
def _expected_alignment(self):
902+
"""
903+
The expected alignment in bytes for the underlying LOAD/STORE operation.
904+
905+
To be implemented by subclasses.
906+
"""
907+
raise NotImplementedError
908+
909+
@property
910+
def _expected_alignment_elems(self):
911+
"""
912+
The expected alignment in number of elements for the underlying
913+
LOAD/STORE operation.
914+
"""
915+
return self._expected_alignment // self.dtype().itemsize
916+
917+
@property
918+
def base(self):
919+
return self.args[0]
920+
921+
func = Reserved._rebuild
922+
923+
@cacheit
924+
def sort_key(self, order=None):
925+
# Ensure that the AlignedAccess is sorted as the base
926+
return self.base.sort_key(order=order)
927+
928+
929+
class VectorAccess(AlignedAccess):
930+
898931
"""
899-
The expected alignment in bytes for the accessed vector. This must be
900-
honored by the compiler for correctness.
932+
Represent a vector access operation at high-level.
901933
"""
902934

903935
def __new__(cls, *args, **kwargs):
@@ -908,21 +940,28 @@ def __str__(self):
908940

909941
__repr__ = __str__
910942

911-
func = Pickable._rebuild
912-
913-
@property
914-
def base(self):
915-
return self.args[0]
943+
@cached_property
944+
def _expected_alignment(self):
945+
"""
946+
The expected alignment in bytes for the underlying LOAD/STORE operation.
947+
"""
948+
mapper = {
949+
# dtype==float => lowered with float4 => 4*4=16 bytes alignment;
950+
np.float32: 16,
951+
# dtype==half => lowered with float2 => 2*4=8 bytes alignment;
952+
np.float16: 8
953+
}
954+
try:
955+
return mapper[self.function.dtype]
956+
except KeyError:
957+
raise ValueError(
958+
f"Unsupported dtype `{self.function.dtype}` for VectorAccess"
959+
)
916960

917961
@property
918962
def indices(self):
919963
return self.base.indices
920964

921-
@cacheit
922-
def sort_key(self, order=None):
923-
# Ensure that the VectorAccess is sorted as the base
924-
return self.base.sort_key(order=order)
925-
926965

927966
# Some other utility objects
928967
Null = Macro('NULL')

devito/types/parallel.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
from functools import cached_property
1212

1313
import numpy as np
14-
from sympy import Expr
1514

1615
from devito.exceptions import InvalidArgument
1716
from devito.parameters import configuration
18-
from devito.symbolics import Reserved, Terminal, search
17+
from devito.symbolics import AlignedAccess, Terminal, search
1918
from devito.tools import as_list, as_tuple, is_integer
2019
from devito.types.array import Array, ArrayObject
2120
from devito.types.basic import Scalar, Symbol
@@ -403,7 +402,7 @@ def __init_finalize__(self, *args, **kwargs):
403402
super().__init_finalize__(*args, **kwargs)
404403

405404

406-
class TensorMove(Expr, Reserved, Terminal):
405+
class TensorMove(AlignedAccess, Terminal):
407406

408407
"""
409408
Represent the LOAD/STORE of a multi-dimensional block of data from/to a higher
@@ -423,17 +422,12 @@ class TensorMove(Expr, Reserved, Terminal):
423422

424423
_expected_alignment = 16
425424
"""
426-
The expected alignment in bytes for the accessed vector. This must be
427-
honored by the compiler for correctness.
425+
The expected alignment in bytes for the underlying LOAD/STORE operation.
428426
"""
429427

430428
def __new__(cls, base, tid0, coords, **kwargs):
431429
return super().__new__(cls, base, tid0, coords)
432430

433-
@property
434-
def base(self):
435-
return self.args[0]
436-
437431
@property
438432
def tid0(self):
439433
return self.args[1]
@@ -442,10 +436,6 @@ def tid0(self):
442436
def coords(self):
443437
return self.args[2]
444438

445-
@property
446-
def function(self):
447-
return self.base.function
448-
449439
@cached_property
450440
def indexed(self):
451441
return self.function[self.coords]
@@ -454,9 +444,5 @@ def indexed(self):
454444
def ndim(self):
455445
return self.function.ndim
456446

457-
func = Reserved._rebuild
458-
459447
def _ccode(self, printer):
460448
return str(self)
461-
462-
_sympystr = _ccode

0 commit comments

Comments
 (0)