Skip to content

Commit 420ecff

Browse files
committed
compiler: Permute Data._decomposition and _modulo on transpose
The numpy transpose primitive only permutes the buffer layout; the companion metadata that Data uses to translate global<->local indices under MPI (Data._decomposition and Data._modulo) was left in the pre-transpose axis order, so any indexing on a transposed Data either silently used the wrong decomposition slice or raised an IndexError that depends on the run's rank topology. Permute both metadata structures alongside the buffer in Data.transpose / Data.swapaxes via the same axes argument so they stay consistent with the new axis order, and add a regression test exercising 2D and 3D transposes against the equivalent numpy.ndarray.transpose result. Fixes #2187
1 parent dc3fa07 commit 420ecff

2 files changed

Lines changed: 123 additions & 0 deletions

File tree

devito/data/data.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,57 @@ def __repr__(self):
208208
def __str__(self):
    """
    Return the string form of ``self._local``.

    The formatting is delegated to the ``__str__`` that follows ``Data``
    in the method resolution order of ``self._local``.
    """
    parent_of_local = super(Data, self._local)
    return parent_of_local.__str__()
210210

211+
def transpose(self, *axes):
    """
    Return a view of ``self`` with its axes permuted.

    The per-axis metadata ``_decomposition`` and ``_modulo`` are reordered
    with the same permutation as the buffer, and ``_is_distributed`` is
    recomputed from the reordered decomposition. This keeps a later slice
    of the transposed view (e.g. ``f.data.T[::2, ::2]``) working against
    metadata that is expressed in the new axis order (see issue #2187).

    Parameters
    ----------
    *axes : int or sequence of int, optional
        The new axis order, given either as separate integers or as one
        sequence. With no axes the order is reversed. Negative entries
        are wrapped modulo ``self.ndim``.

    Returns
    -------
    The view produced by the parent ``transpose``, carrying the permuted
    metadata.
    """
    # A lone argument is unwrapped through ``as_tuple`` so that a single
    # tuple/list behaves like per-argument axes (presumably ``None`` maps
    # to an empty tuple here -- confirm against ``as_tuple``).
    if len(axes) == 1:
        axes = as_tuple(axes[0])

    ndim = self.ndim
    if axes:
        permutation = tuple(ax % ndim for ax in axes)
    else:
        # No explicit order: full reversal of the axes
        permutation = tuple(range(ndim - 1, -1, -1))

    transposed = super().transpose(*axes)

    decomposition = tuple(self._decomposition[i] for i in permutation)
    transposed._decomposition = decomposition
    transposed._modulo = tuple(self._modulo[i] for i in permutation)
    transposed._is_distributed = any(d is not None for d in decomposition)

    return transposed
237+
238+
def swapaxes(self, axis1, axis2):
    """
    Return a view of ``self`` with ``axis1`` and ``axis2`` interchanged.

    The per-axis metadata ``_decomposition`` and ``_modulo`` are swapped in
    the same way as the buffer, and ``_is_distributed`` is recomputed from
    the swapped decomposition, so that later indexing on the view sees the
    metadata in the new axis order.

    Parameters
    ----------
    axis1 : int
        First axis. Negative values count from the last axis.
    axis2 : int
        Second axis. Negative values count from the last axis.

    Returns
    -------
    The view produced by the parent ``swapaxes``, carrying the swapped
    metadata.

    Raises
    ------
    Whatever the parent ``swapaxes`` raises for an invalid axis
    (presumably ``numpy`` raises an ``AxisError`` -- the parent class is
    not visible from this block).
    """
    # Swap the buffer first, handing over the caller's axes untouched, so
    # that the parent implementation validates them. Reducing the axes
    # with ``%`` beforehand would silently wrap an out-of-range axis
    # (e.g. ``2`` on a 2D array) onto a valid one, and would raise
    # ``ZeroDivisionError`` on a 0-d array.
    ret = super().swapaxes(axis1, axis2)

    # The axes were accepted above, so the modulo only maps negative
    # values onto their non-negative equivalents.
    ndim = self.ndim
    axis1 = axis1 % ndim
    axis2 = axis2 % ndim

    order = list(range(ndim))
    order[axis1], order[axis2] = order[axis2], order[axis1]

    ret._decomposition = tuple(self._decomposition[i] for i in order)
    ret._modulo = tuple(self._modulo[i] for i in order)
    ret._is_distributed = any(d is not None for d in ret._decomposition)
    return ret
253+
254+
@property
def T(self):
    """
    The transposed view, i.e. ``self.transpose()`` with no arguments.

    Defined explicitly so that attribute-style transposition goes through
    :meth:`transpose` instead of an inherited ``T`` shortcut.
    """
    transposed = self.transpose()
    return transposed
261+
211262
@_check_idx
212263
def __getitem__(self, glb_idx, comm_type, gather_rank=None):
213264
loc_idx = self._index_glb_to_loc(glb_idx)

tests/test_data.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,78 @@ def test_indexing_into_sparse(self):
211211
sf.data[1:-1, 0] = np.arange(8)
212212
assert np.all(sf.data[1:-1, 0] == np.arange(8))
213213

214+
def test_slice_after_transpose(self):
    """
    Check that slicing a transposed ``Data`` view agrees with NumPy.

    Every case builds a view of the ``Function`` data through ``.T``,
    ``transpose`` or ``swapaxes``, slices it, and compares the outcome
    with the same operations applied to a plain ``numpy.ndarray`` copy
    of the data. Regression test for issue #2187.
    """
    shape2d = (4, 6)
    f = Function(name='f', grid=Grid(shape=shape2d))
    f.data[:] = np.arange(24).reshape(shape2d).astype(np.float32)
    expected = np.array(f.data)

    every_other = np.s_[::2, ::2]

    # Transpose through the ``.T`` attribute, then slice
    assert np.array_equal(f.data.T[every_other], expected.T[every_other])

    # The other way round: slice first, transpose afterwards
    assert np.array_equal(f.data[every_other].T, expected[every_other].T)

    # Argument-less ``transpose()`` must match ``.T``
    assert np.array_equal(f.data.transpose()[every_other],
                          expected.transpose()[every_other])

    # ``swapaxes`` over two axes of different extent
    assert np.array_equal(f.data.swapaxes(0, 1)[every_other],
                          expected.swapaxes(0, 1)[every_other])

    # 3D data: full reversal, then an explicit axis order with a
    # different step on every axis
    shape3d = (2, 4, 6)
    g = Function(name='g3', grid=Grid(shape=shape3d))
    g.data[:] = np.arange(48).reshape(shape3d).astype(np.float32)
    expected3d = np.array(g.data)

    uniform = np.s_[::2, ::2, ::2]
    assert np.array_equal(g.data.T[uniform], expected3d.T[uniform])

    axes = (1, 0, 2)
    mixed = np.s_[::2, ::1, ::3]
    assert np.array_equal(g.data.transpose(axes)[mixed],
                          expected3d.transpose(axes)[mixed])
251+
252+
def test_transpose_permutes_data_metadata(self):
    """
    Check that transpose-like operations reorder the per-axis metadata.

    ``_decomposition`` and ``_modulo`` of the returned view must follow
    the same permutation as the axes, for ``.T``, ``transpose`` and
    ``swapaxes``.
    """
    f = Function(name='f', grid=Grid(shape=(4, 6)))

    decomp = f.data._decomposition
    assert len(decomp) == 2

    # Attribute form: both metadata tuples are reversed
    via_attr = f.data.T
    assert via_attr._decomposition == decomp[::-1]
    assert via_attr._modulo == f.data._modulo[::-1]

    # Argument-less call form behaves like ``.T``
    via_call = f.data.transpose()
    assert via_call._decomposition == decomp[::-1]

    # Swapping the only two axes exchanges their entries
    swapped = f.data.swapaxes(0, 1)
    assert swapped._decomposition == (decomp[1], decomp[0])

    # Explicit axis order on 3D data
    g = Function(name='g3', grid=Grid(shape=(2, 4, 6)))
    decomp3d = g.data._decomposition
    axes = (1, 2, 0)
    permuted = g.data.transpose(axes)
    assert permuted._decomposition == tuple(decomp3d[i] for i in axes)
    assert permuted._modulo == tuple(g.data._modulo[i] for i in axes)
285+
214286
@pytest.mark.parallel(mode=1)
215287
def test_indexing_into_sparse_subfunc_singlempi(self, mode):
216288
grid = Grid(shape=(4, 4))

0 commit comments

Comments
 (0)