Skip to content

Commit 143f99b

Browse files
committed
Account for non-contiguous arrays in PyOpenCLActx.to_numpy.
1 parent b30d74e commit 143f99b

2 files changed

Lines changed: 72 additions & 2 deletions

File tree

arraycontext/impl/pyopencl/__init__.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
import numpy as np
3838
from typing_extensions import Self, override
3939

40+
from pytools import memoize_method
41+
4042
from arraycontext.container.traversal import (
4143
rec_map_array_container,
4244
rec_map_container,
@@ -62,7 +64,7 @@
6264
if TYPE_CHECKING:
6365
from collections.abc import Callable, Mapping
6466

65-
from numpy.typing import NDArray
67+
from numpy.typing import DTypeLike, NDArray
6668

6769
import loopy as lp
6870
import pyopencl as cl
@@ -263,12 +265,64 @@ def to_numpy(self, array: Array) -> np.ndarray:
263265
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
264266
...
265267

268+
@memoize_method
def _get_to_numpy_contiguous_copy_kernel(
    self, dtype: DTypeLike, ndim: int
) -> lp.TranslationUnit:
    """Build a loopy program that copies ``inp`` into ``out`` element by
    element over *ndim* axes.

    ``inp`` is declared with one symbolic stride ``s<k>`` per axis, while
    ``out`` is declared without explicit strides. The loop bounds are the
    symbolic extents ``n<k>``. The strides and extents are declared as
    64-bit integer value arguments.

    The result is cached per ``(dtype, ndim)`` by :func:`memoize_method`.

    :arg dtype: element type used for both ``inp`` and ``out``.
    :arg ndim: number of array axes the program iterates over.
    :returns: the program produced by :func:`make_loopy_program`, named
        ``to_numpy_contiguous_copy_<ndim>d``.
    """
    import loopy as lp

    from arraycontext import make_loopy_program

    # One loop index, one extent and one stride name per array axis.
    axes = range(ndim)
    index_names = [f"i{axis}" for axis in axes]
    extent_names = [f"n{axis}" for axis in axes]
    stride_names = [f"s{axis}" for axis in axes]

    index_list = ", ".join(index_names)
    bounds = " and ".join(
        f"0 <= {index} < {extent}"
        for index, extent in zip(index_names, extent_names, strict=True)
    )
    domain = f"{{ [{index_list}] : {bounds} }}"

    return make_loopy_program(
        [domain],
        [f"out[{index_list}] = inp[{index_list}]"],
        kernel_data=[
            lp.GlobalArg("out", dtype=dtype, shape=lp.auto),
            lp.GlobalArg(
                "inp",
                dtype=dtype,
                strides=tuple(stride_names),
                shape=lp.auto,
            ),
            # Comma-separated names declare several value arguments at once.
            lp.ValueArg(",".join(stride_names), dtype=np.int64),
            lp.ValueArg(",".join(extent_names), dtype=np.int64),
        ],
        name=f"to_numpy_contiguous_copy_{ndim}d",
    )
307+
266308
@override
267309
def to_numpy(self,
268310
array: ArrayOrContainerOrScalar
269311
) -> NumpyOrContainerOrScalar:
270312
def _to_numpy(ary):
271-
return ary.get(queue=self.queue)
313+
if ary.flags.forc:
314+
# pyopencl supports host transfers only for contiguous arrays.
315+
return ary.get(queue=self.queue)
316+
317+
result = self.call_loopy(
318+
self._get_to_numpy_contiguous_copy_kernel(ary.dtype, ary.ndim),
319+
inp=ary,
320+
**{
321+
f"s{i}": stride // ary.dtype.itemsize
322+
for i, stride in enumerate(ary.strides)
323+
},
324+
)["out"]
325+
return result.get(queue=self.queue)
272326

273327
return with_array_context(
274328
self._rec_map_container(_to_numpy, array),

test/test_arraycontext.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,6 +1659,22 @@ def test_linspace(actx_factory: ArrayContextFactory, args, kwargs):
16591659
assert np.allclose(actx_linspace, np_linspace)
16601660

16611661

1662+
# {{{ test_to_numpy_transpose
1663+
1664+
def test_to_numpy_transpose(actx_factory: ArrayContextFactory):
    """Check that ``to_numpy`` returns correct data for a transposed array.

    The array is transposed through ``actx.np.transpose`` before the host
    transfer, and the result is compared against the same transpose in numpy.
    """
    actx = actx_factory()

    # Fixed seed so that a failure can be reproduced exactly.
    rng = np.random.default_rng(seed=42)

    # Distinct extents per axis: a result with permuted extents then has the
    # wrong shape and fails the comparison. The small size keeps the test
    # cheap in host and device memory.
    np_ary = rng.random((17, 23, 31))
    ary = actx.from_numpy(np_ary)
    axis_perm = (0, 2, 1)

    np.testing.assert_allclose(
        actx.to_numpy(actx.np.transpose(ary, axis_perm)),
        np.transpose(np_ary, axis_perm))
1674+
1675+
# }}}
1676+
1677+
16621678
if __name__ == "__main__":
16631679
import sys
16641680
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)