Skip to content

Commit 766d2dc

Browse files
committed
Account non-contiguous arys in PyOpenCLActx.to_numpy.
1 parent 3565340 commit 766d2dc

2 files changed

Lines changed: 65 additions & 2 deletions

File tree

arraycontext/impl/pyopencl/__init__.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
import numpy as np
3838
from typing_extensions import Self, override
3939

40+
from pytools import memoize_method
41+
4042
from arraycontext.container.traversal import (
4143
rec_map_array_container,
4244
rec_map_container,
@@ -62,7 +64,7 @@
6264
if TYPE_CHECKING:
6365
from collections.abc import Callable, Mapping
6466

65-
from numpy.typing import NDArray
67+
from numpy.typing import DTypeLike, NDArray
6668

6769
import loopy as lp
6870
import pyopencl as cl
@@ -263,12 +265,54 @@ def to_numpy(self, array: Array) -> np.ndarray:
263265
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
264266
...
265267

268+
@memoize_method
269+
def _get_to_numpy_noncontiguous_copy_kernel(
270+
self, dtype: DTypeLike, ndim: int
271+
) -> lp.TranslationUnit:
272+
"""
273+
Returns a translation unit containing a loopy kernel that:
274+
275+
- Accepts a PyOpenCL array ``inp`` with per-axis strides exposed as
276+
``s0, s1, ..., s{ndim-1}``.
277+
- Produces a contiguous, row-major (C-order) output array ``output`` of
278+
the same shape, with elements copied from the corresponding
279+
coordinates in ``input``.
280+
"""
281+
282+
import loopy as lp
283+
284+
from arraycontext.loopy import _DEFAULT_LOOPY_OPTIONS
285+
286+
t_unit = lp.make_copy_kernel(
287+
["c"] * ndim, [f"stride:s{i}" for i in range(ndim)]
288+
)
289+
t_unit = lp.add_dtypes(t_unit, {"input": dtype})
290+
new_args = [
291+
*t_unit.default_entrypoint.args,
292+
*[lp.ValueArg(f"s{i}", dtype=np.uint64) for i in range(ndim)],
293+
]
294+
t_unit = t_unit.with_kernel(t_unit.default_entrypoint.copy(args=new_args))
295+
t_unit = lp.set_options(t_unit, _DEFAULT_LOOPY_OPTIONS)
296+
return t_unit
297+
266298
@override
267299
def to_numpy(self,
268300
array: ArrayOrContainerOrScalar
269301
) -> NumpyOrContainerOrScalar:
270302
def _to_numpy(ary):
271-
return ary.get(queue=self.queue)
303+
if ary.flags.forc:
304+
# pyopencl supports host transfers only for contiguous arrays.
305+
return ary.get(queue=self.queue)
306+
307+
result = self.call_loopy(
308+
self._get_to_numpy_noncontiguous_copy_kernel(ary.dtype, ary.ndim),
309+
input=ary,
310+
**{
311+
f"s{i}": stride // ary.dtype.itemsize
312+
for i, stride in enumerate(ary.strides)
313+
},
314+
)["output"]
315+
return result.get(queue=self.queue)
272316

273317
return with_array_context(
274318
self._rec_map_container(_to_numpy, array),

test/test_arraycontext.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,6 +1659,25 @@ def test_linspace(actx_factory: ArrayContextFactory, args, kwargs):
16591659
assert np.allclose(actx_linspace, np_linspace)
16601660

16611661

1662+
# {{{ test_to_numpy_transpose
1663+
1664+
def test_to_numpy_transpose(actx_factory: ArrayContextFactory):
1665+
# fails prior to <https://github.com/inducer/arraycontext/pull/357> for
1666+
# pyopencl actx -- cl_array.Array.transpose generates non-contiguous
1667+
# arrays requiring non-trivial logic for to host copies.
1668+
actx = actx_factory()
1669+
rng = np.random.default_rng()
1670+
np_ary = rng.random((256, 256, 256))
1671+
ary = actx.from_numpy(np_ary)
1672+
axis_perm = (0, 2, 1)
1673+
1674+
np.testing.assert_allclose(
1675+
actx.to_numpy(actx.np.transpose(ary, axis_perm)),
1676+
np.transpose(np_ary, axis_perm))
1677+
1678+
# }}}
1679+
1680+
16621681
if __name__ == "__main__":
16631682
import sys
16641683
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)