Skip to content

Commit 4503028

Browse files
committed
compiler: fix linearization padding check
1 parent 5d07778 commit 4503028

2 files changed

Lines changed: 32 additions & 14 deletions

File tree

devito/passes/iet/linearization.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,7 @@ def key1(f, d):
7070
if f.is_regular:
7171
# For paddable objects the following holds:
7272
# `same dim + same halo + same padding_dtype => same (auto-)padding`
73-
if d is f.dimensions[-1]:
74-
# Only the last dimension is padded
75-
try:
76-
if f.padding == f.mapped.padding:
77-
# Padding set from the mapped Function
78-
# e.g. from buffering or fft temp array
79-
pad_key = f.mapped.__padding_dtype__
80-
else:
81-
pad_key = f.__padding_dtype__
82-
except AttributeError:
83-
pad_key = f.__padding_dtype__
84-
else:
85-
pad_key = None
73+
pad_key = f.__padding_dtype__ if d is f.dimensions[-1] else None
8674

8775
return (d, f._size_halo[d], pad_key)
8876
else:

devito/types/array.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,37 @@ class MappedArrayMixin:
247247

248248

249249
class ArrayMapped(MappedArrayMixin, Array):
250-
is_autopaddable = True
250+
251+
__rkwargs__ = Array.__rkwargs__ + ('mapped',)
252+
253+
def __init_finalize__(self, *args, **kwargs):
254+
self._mapped = kwargs.get('mapped')
255+
super().__init_finalize__(*args, **kwargs)
256+
257+
@property
258+
def mapped(self):
259+
return self._mapped
260+
261+
@property
262+
def is_autopaddable(self):
263+
if self.mapped is None:
264+
return True
265+
return self.mapped.is_autopaddable
266+
267+
@property
268+
def __padding_dtype__(self):
269+
if self.mapped is not None:
270+
return self.mapped.__padding_dtype__
271+
return super().__padding_dtype__
272+
273+
@cached_property
274+
def _signature(self):
275+
# Exclude `mapped` so buf-reuse can dedup buffers across distinct
276+
# mapped Functions
277+
ret = [type(self), self.indices]
278+
attrs = set(self.__rkwargs__) - {'name', 'function', 'mapped'}
279+
ret.extend(getattr(self, i) for i in attrs)
280+
return frozenset(ret)
251281

252282

253283
class ArrayObject(ArrayBasic):

0 commit comments

Comments
 (0)