Skip to content

Commit 836b318

Browse files
committed
Account for non-contiguous arys in PyOpenCLActx.to_numpy.
1 parent 3565340 commit 836b318

2 files changed

Lines changed: 85 additions & 2 deletions

File tree

arraycontext/impl/pyopencl/__init__.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
import numpy as np
3838
from typing_extensions import Self, override
3939

40+
from pytools import memoize_method
41+
4042
from arraycontext.container.traversal import (
4143
rec_map_array_container,
4244
rec_map_container,
@@ -62,7 +64,7 @@
6264
if TYPE_CHECKING:
6365
from collections.abc import Callable, Mapping
6466

65-
from numpy.typing import NDArray
67+
from numpy.typing import DTypeLike, NDArray
6668

6769
import loopy as lp
6870
import pyopencl as cl
@@ -263,12 +265,74 @@ def to_numpy(self, array: Array) -> np.ndarray:
263265
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
264266
...
265267

268+
@memoize_method
def _get_to_numpy_noncontiguous_copy_kernel(
        self, dtype: DTypeLike, ndim: int
        ) -> lp.TranslationUnit:
    """
    Build the loopy program used to gather a strided array into a
    contiguous one.

    The returned translation unit holds a single kernel named
    ``to_numpy_noncontiguous_copy_{ndim}d`` that:

    - reads a global array ``inp`` of *dtype* whose per-axis strides are
      the value arguments ``s0, s1, ..., s{ndim-1}``,
    - iterates over ``0 <= i{k} < n{k}`` for every axis ``k``, and
    - writes ``out[i0, ...] = inp[i0, ...]`` into a contiguous, row-major
      (C-order) output array ``out`` of the same shape.

    :arg dtype: element type of both ``inp`` and ``out``.
    :arg ndim: number of array axes the kernel is generated for.
    """
    # Imported locally, matching how the rest of this method's
    # dependencies are pulled in at call time.
    import loopy as lp

    from arraycontext import make_loopy_program

    axes = range(ndim)

    # Index expression shared by both sides of the copy, e.g. "i0, i1".
    index_list = ", ".join(f"i{ax}" for ax in axes)

    # One "0 <= i < n" constraint per axis, joined into an ISL-style domain.
    bounds = " and ".join(f"0 <= i{ax} < n{ax}" for ax in axes)
    domain = f"{{ [{index_list}] : {bounds} }}"

    stride_names = tuple(f"s{ax}" for ax in axes)
    extent_names = [f"n{ax}" for ax in axes]

    out_arg = lp.GlobalArg("out", dtype=dtype, shape=lp.auto)
    # The input's strides are symbolic so a single kernel serves every
    # memory layout of a given dtype/ndim.
    inp_arg = lp.GlobalArg(
        "inp",
        dtype=dtype,
        strides=stride_names,
        offset=lp.auto,
        shape=lp.auto,
    )

    return make_loopy_program(
        [domain],
        [f"out[{index_list}] = inp[{index_list}]"],
        kernel_data=[
            out_arg,
            inp_arg,
            lp.ValueArg(",".join(stride_names), dtype=np.int64),
            lp.ValueArg(",".join(extent_names), dtype=np.int64),
        ],
        name=f"to_numpy_noncontiguous_copy_{ndim}d",
    )
317+
266318
@override
267319
def to_numpy(self,
268320
array: ArrayOrContainerOrScalar
269321
) -> NumpyOrContainerOrScalar:
270322
def _to_numpy(ary):
271-
return ary.get(queue=self.queue)
323+
if ary.flags.forc:
324+
# pyopencl supports host transfers only for contiguous arrays.
325+
return ary.get(queue=self.queue)
326+
327+
result = self.call_loopy(
328+
self._get_to_numpy_noncontiguous_copy_kernel(ary.dtype, ary.ndim),
329+
inp=ary,
330+
**{
331+
f"s{i}": stride // ary.dtype.itemsize
332+
for i, stride in enumerate(ary.strides)
333+
},
334+
)["out"]
335+
return result.get(queue=self.queue)
272336

273337
return with_array_context(
274338
self._rec_map_container(_to_numpy, array),

test/test_arraycontext.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,6 +1659,25 @@ def test_linspace(actx_factory: ArrayContextFactory, args, kwargs):
16591659
assert np.allclose(actx_linspace, np_linspace)
16601660

16611661

1662+
# {{{ test_to_numpy_transpose
1663+
1664+
def test_to_numpy_transpose(actx_factory: ArrayContextFactory):
    """Check that ``to_numpy`` round-trips a transposed device array."""
    # Regression test for <https://github.com/inducer/arraycontext/pull/357>.
    # Before that change this failed for the pyopencl actx:
    # cl_array.Array.transpose yields non-contiguous arrays, and copying
    # those to the host requires non-trivial logic.
    actx = actx_factory()

    host_data = np.random.default_rng().random((256, 256, 256))
    perm = (0, 2, 1)

    device_data = actx.from_numpy(host_data)
    transposed_on_device = actx.np.transpose(device_data, perm)

    actual = actx.to_numpy(transposed_on_device)
    expected = np.transpose(host_data, perm)
    np.testing.assert_allclose(actual, expected)
1677+
1678+
# }}}
1679+
1680+
16621681
if __name__ == "__main__":
16631682
import sys
16641683
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)