Skip to content

Commit ac5af77

Browse files
committed
review: apply @mloubout's suggestions
* use ``devito.tools.as_tuple`` to normalize the polymorphic ``axes`` of ``transpose`` (no-args, single ``None``, single tuple/list, or individual positional args all collapse to a flat tuple). * drop the ``isinstance(ret, Data)`` guards in ``transpose`` and ``swapaxes`` -- ``numpy.ndarray.<view-op>`` preserves the subclass via ``__array_finalize__``, so ``ret`` is always a ``Data``. * replace the hand-built reference in ``test_slice_after_transpose`` with ``ref = np.array(f.data)`` (and the 3D equivalent for ``g``), removing the duplicated ``arange/reshape/astype`` lines; shape and values are still compared against the NumPy reference array via ``np.array_equal``.
1 parent b72317a commit ac5af77

2 files changed

Lines changed: 27 additions & 42 deletions

File tree

devito/data/data.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -220,20 +220,19 @@ def transpose(self, *axes):
220220
wrong per-axis decomposition and silently returns a wrong-shaped
221221
result (see issue #2187).
222222
"""
223-
if not axes or axes == (None,):
224-
new_order = tuple(range(self.ndim - 1, -1, -1))
225-
elif len(axes) == 1 and isinstance(axes[0], Iterable):
226-
new_order = tuple(axes[0])
227-
else:
228-
new_order = tuple(axes)
229-
# Normalize negative axes (numpy accepts them)
230-
new_order = tuple(ax % self.ndim for ax in new_order)
223+
# Accept the same axis-spec forms as ``numpy.ndarray.transpose``:
224+
# no args, a single ``None``, a single tuple/list, or per-arg.
225+
if len(axes) == 1:
226+
axes = as_tuple(axes[0])
227+
new_order = (
228+
tuple(range(self.ndim - 1, -1, -1)) if not axes
229+
else tuple(ax % self.ndim for ax in axes)
230+
)
231231

232232
ret = super().transpose(*axes)
233-
if isinstance(ret, Data):
234-
ret._decomposition = tuple(self._decomposition[i] for i in new_order)
235-
ret._modulo = tuple(self._modulo[i] for i in new_order)
236-
ret._is_distributed = any(d is not None for d in ret._decomposition)
233+
ret._decomposition = tuple(self._decomposition[i] for i in new_order)
234+
ret._modulo = tuple(self._modulo[i] for i in new_order)
235+
ret._is_distributed = any(d is not None for d in ret._decomposition)
237236
return ret
238237

239238
def swapaxes(self, axis1, axis2):
@@ -245,12 +244,11 @@ def swapaxes(self, axis1, axis2):
245244
axis1 = axis1 % self.ndim
246245
axis2 = axis2 % self.ndim
247246
ret = super().swapaxes(axis1, axis2)
248-
if isinstance(ret, Data):
249-
order = list(range(self.ndim))
250-
order[axis1], order[axis2] = order[axis2], order[axis1]
251-
ret._decomposition = tuple(self._decomposition[i] for i in order)
252-
ret._modulo = tuple(self._modulo[i] for i in order)
253-
ret._is_distributed = any(d is not None for d in ret._decomposition)
247+
order = list(range(self.ndim))
248+
order[axis1], order[axis2] = order[axis2], order[axis1]
249+
ret._decomposition = tuple(self._decomposition[i] for i in order)
250+
ret._modulo = tuple(self._modulo[i] for i in order)
251+
ret._is_distributed = any(d is not None for d in ret._decomposition)
254252
return ret
255253

256254
@property

tests/test_data.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -223,44 +223,31 @@ def test_slice_after_transpose(self):
223223
grid = Grid(shape=(4, 6))
224224
f = Function(name='f', grid=grid)
225225
f.data[:] = np.arange(24).reshape((4, 6)).astype(np.float32)
226-
227-
ref = np.arange(24).reshape((4, 6)).astype(np.float32)
226+
ref = np.array(f.data)
228227

229228
# ``.T`` (C-level shortcut) then slice
230-
out = f.data.T[::2, ::2]
231-
assert out.shape == ref.T[::2, ::2].shape
232-
assert np.array_equal(np.asarray(out), ref.T[::2, ::2])
229+
assert np.array_equal(f.data.T[::2, ::2], ref.T[::2, ::2])
233230

234231
# Equivalent: slice then ``.T``
235-
out2 = f.data[::2, ::2].T
236-
assert out2.shape == ref[::2, ::2].T.shape
237-
assert np.array_equal(np.asarray(out2), ref[::2, ::2].T)
232+
assert np.array_equal(f.data[::2, ::2].T, ref[::2, ::2].T)
238233

239234
# Explicit ``transpose`` call -- same behavior as ``.T``
240-
out3 = f.data.transpose()[::2, ::2]
241-
assert out3.shape == ref.transpose()[::2, ::2].shape
242-
assert np.array_equal(np.asarray(out3), ref.transpose()[::2, ::2])
235+
assert np.array_equal(f.data.transpose()[::2, ::2],
236+
ref.transpose()[::2, ::2])
243237

244238
# ``swapaxes`` between non-conforming dims
245-
out4 = f.data.swapaxes(0, 1)[::2, ::2]
246-
assert out4.shape == ref.swapaxes(0, 1)[::2, ::2].shape
247-
assert np.array_equal(np.asarray(out4), ref.swapaxes(0, 1)[::2, ::2])
239+
assert np.array_equal(f.data.swapaxes(0, 1)[::2, ::2],
240+
ref.swapaxes(0, 1)[::2, ::2])
248241

249242
# 3D transpose with an explicit axis order, then per-axis slice
250243
grid3 = Grid(shape=(2, 4, 6))
251244
g = Function(name='g3', grid=grid3)
252245
g.data[:] = np.arange(48).reshape((2, 4, 6)).astype(np.float32)
253-
ref3 = np.arange(48).reshape((2, 4, 6)).astype(np.float32)
254-
255-
out5 = g.data.T[::2, ::2, ::2]
256-
assert out5.shape == ref3.T[::2, ::2, ::2].shape
257-
assert np.array_equal(np.asarray(out5), ref3.T[::2, ::2, ::2])
246+
ref3 = np.array(g.data)
258247

259-
out6 = g.data.transpose((1, 0, 2))[::2, ::1, ::3]
260-
assert out6.shape == ref3.transpose((1, 0, 2))[::2, ::1, ::3].shape
261-
assert np.array_equal(
262-
np.asarray(out6), ref3.transpose((1, 0, 2))[::2, ::1, ::3]
263-
)
248+
assert np.array_equal(g.data.T[::2, ::2, ::2], ref3.T[::2, ::2, ::2])
249+
assert np.array_equal(g.data.transpose((1, 0, 2))[::2, ::1, ::3],
250+
ref3.transpose((1, 0, 2))[::2, ::1, ::3])
264251

265252
def test_transpose_permutes_data_metadata(self):
266253
"""

0 commit comments

Comments
 (0)