Skip to content

Commit e50a4a7

Browse files
Basic support for device arrays (#772)
* cupy compatibility Fix empty allocation Register cubed.Array in to_numpy Some tests of Array implementing the array API need to convert from the active backend to NumPy. * skip processes for cupy * array api passing * fixup * review * ignore cupy type annotations * try to handle jax setitem * more jax fixes
1 parent 735c80e commit e50a4a7

9 files changed

Lines changed: 121 additions & 9 deletions

File tree

cubed/_testing.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import functools
2+
import importlib.util
3+
4+
import numpy as np
5+
import numpy.testing as npt
6+
7+
import cubed
8+
9+
10+
@functools.cache
def has_cupy() -> bool:
    """Return whether CuPy is installed *and* can actually be imported.

    ``importlib.util.find_spec`` only tells us that a ``cupy`` package is
    discoverable on the import path. A discoverable package can still fail
    to import (for example a broken or partial installation), and callers
    use this function to guard a module-level ``import cupy``. The import
    is therefore attempted here, so a ``True`` result guarantees that the
    guarded import will succeed.

    The result is cached, so the import is attempted at most once per
    process.

    Returns
    -------
    bool
        ``True`` if ``import cupy`` succeeds, ``False`` otherwise.
    """
    # Cheap check first: nothing to import if the package is not installed.
    if importlib.util.find_spec("cupy") is None:
        return False
    try:
        importlib.import_module("cupy")
    except ImportError:
        # Installed but unusable; treat exactly like "not installed".
        return False
    return True
13+
14+
15+
@functools.singledispatch
def to_numpy(a):
    """Convert ``a`` to a NumPy array.

    This is the fallback implementation of a single-dispatch function: any
    type without a more specific registered overload is handed to
    ``np.asarray``.
    """
    host_array = np.asarray(a)
    return host_array
18+
19+
20+
@to_numpy.register(cubed.Array)
def _(a: cubed.Array) -> np.ndarray:
    """``to_numpy`` overload for ``cubed.Array``.

    The array is computed first; the computed result is then dispatched
    through ``to_numpy`` again, so whatever type ``compute()`` returns is
    converted by its own overload.
    """
    computed = a.compute()
    return to_numpy(computed)
23+
24+
25+
# Only register the CuPy overload when CuPy is available, so that this module
# stays importable on machines without it.
if has_cupy():
    import cupy

    @to_numpy.register(cupy.ndarray)
    def _(a):
        """``to_numpy`` overload for ``cupy.ndarray``: return ``a.get()``."""
        host_copy = a.get()
        return host_copy
31+
32+
33+
def assert_array_equal(a, b, **kwargs):
    """Backend-aware wrapper around ``numpy.testing.assert_array_equal``.

    Both operands are converted with ``to_numpy`` before being compared.
    Any keyword arguments are forwarded unchanged to the NumPy assertion.
    """
    npt.assert_array_equal(to_numpy(a), to_numpy(b), **kwargs)
37+
38+
39+
def assert_allclose(a, b, **kwargs):
    """Backend-aware wrapper around ``numpy.testing.assert_allclose``.

    Both operands are converted with ``to_numpy`` before being compared.
    Any keyword arguments (e.g. tolerances) are forwarded unchanged to the
    NumPy assertion.
    """
    npt.assert_allclose(to_numpy(a), to_numpy(b), **kwargs)

cubed/array_api/manipulation_functions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from toolz import reduce
77

88
from cubed.array_api.creation_functions import empty
9+
from cubed.backend_array_api import IS_IMMUTABLE_ARRAY
910
from cubed.backend_array_api import namespace as nxp
1011
from cubed.core import (
1112
blockwise,
@@ -213,14 +214,17 @@ def _read_concat_chunk(
213214
stop = min(stop, target_shape[axis])
214215

215216
chunk_shape = tuple(ch[bi] for ch, bi in zip(target_chunks, block_id))
216-
out = np.empty(chunk_shape, dtype=dtype)
217+
out = nxp.empty(chunk_shape, dtype=dtype)
217218
for array, (lchunk_selection, lout_selection) in zip(
218219
arrays,
219220
_chunk_slices(
220221
offsets, start, stop, target_chunks, chunksize, in_shapes, axis, block_id
221222
),
222223
):
223-
out[lout_selection] = array[lchunk_selection]
224+
if IS_IMMUTABLE_ARRAY:
225+
out = out.at[lout_selection].set(array[lchunk_selection])
226+
else:
227+
out[lout_selection] = array[lchunk_selection]
224228
return out
225229

226230

cubed/backend_array_api.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,32 @@
3232
import array_api_compat.numpy
3333

3434
namespace = array_api_compat.numpy
35+
xp_name = "numpy"
3536

3637

3738
# These functions to convert to/from backend arrays
3839
# assume that no extra memory is allocated, by using the
3940
# Python buffer protocol.
4041
# See https://data-apis.org/array-api/latest/API_specification/generated/array_api.asarray.html
42+
if "cupy" in namespace.__name__:
43+
# zarr-python 3.x natively supports some device buffers (currently just cupy,
44+
# but https://github.com/zarr-developers/zarr-python/issues/2658 is expanding the
45+
# set). For these backends, we *don't* want to copy to the host.
4146

47+
def backend_array_to_numpy_array(arr):
48+
return arr
4249

43-
def backend_array_to_numpy_array(arr):
44-
return np.asarray(arr)
50+
else:
51+
52+
def backend_array_to_numpy_array(arr):
53+
return np.asarray(arr)
4554

4655

4756
def numpy_array_to_backend_array(arr, *, dtype=None):
    """Convert ``arr`` to an array of the configured backend namespace.

    Parameters
    ----------
    arr
        Either a single array-like object, or a ``dict`` whose values are
        array-like objects.
    dtype, optional
        Passed through to the backend's ``asarray``.

    Returns
    -------
    A backend array, or, when ``arr`` is a ``dict``, a new ``dict`` with the
    same keys whose values have each been converted.
    """
    if not isinstance(arr, dict):
        return namespace.asarray(arr, dtype=dtype)

    # Dict input: convert every value individually, keeping the keys.
    converted = {}
    for key, value in arr.items():
        converted[key] = namespace.asarray(value, dtype=dtype)
    return converted
60+
61+
62+
# jax doesn't support in-place assignment, so we use .at[].set() instead.
63+
IS_IMMUTABLE_ARRAY = "jax" in xp_name

cubed/core/ops.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from toolz import map
1515

1616
from cubed import config
17+
from cubed.backend_array_api import IS_IMMUTABLE_ARRAY, numpy_array_to_backend_array
1718
from cubed.backend_array_api import namespace as nxp
18-
from cubed.backend_array_api import numpy_array_to_backend_array
1919
from cubed.core.array import CoreArray, check_array_specs, compute, gensym
2020
from cubed.core.plan import Plan, new_temp_path
2121
from cubed.primitive.blockwise import blockwise as primitive_blockwise
@@ -559,14 +559,17 @@ def _assemble_index_chunk(
559559
indexer = _create_zarr_indexer(in_sel, in_shape, in_chunksize)
560560

561561
shape = indexer.shape
562-
out = np.empty(shape, dtype=dtype)
562+
out = nxp.empty(shape, dtype=dtype)
563563

564564
if array_size(shape) > 0:
565565
_, lchunk_selection, lout_selection, *_ = zip(*indexer)
566566
for ai, chunk_select, out_select in zip(
567567
arrays, lchunk_selection, lout_selection
568568
):
569-
out[out_select] = ai[chunk_select]
569+
if IS_IMMUTABLE_ARRAY:
570+
out = out.at[out_select].set(ai[chunk_select])
571+
else:
572+
out[out_select] = ai[chunk_select]
570573

571574
if func is not None:
572575
if has_keyword(func, "block_id"):

cubed/tests/test_array_api.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
import fsspec
22
import numpy as np
33
import pytest
4-
from numpy.testing import assert_allclose, assert_array_equal
54

65
import cubed
76
import cubed.array_api as xp
7+
from cubed._testing import assert_allclose, assert_array_equal
88
from cubed.array_api.manipulation_functions import reshape_chunks
9-
from cubed.tests.utils import ALL_EXECUTORS, MAIN_EXECUTORS, MODAL_EXECUTORS
9+
from cubed.backend_array_api import namespace as nxp
10+
from cubed.tests.utils import (
11+
ALL_EXECUTORS,
12+
MAIN_EXECUTORS,
13+
MODAL_EXECUTORS,
14+
skip_if_cupy,
15+
)
1016

1117

1218
@pytest.fixture
@@ -20,6 +26,8 @@ def spec(tmp_path):
2026
ids=[executor.name for executor in MAIN_EXECUTORS],
2127
)
2228
def executor(request):
29+
if request.param.name == "processes" and "cupy" in nxp.__name__:
30+
pytest.skip(reason="CuPy is not supported with 'processes' executor")
2331
return request.param
2432

2533

@@ -29,6 +37,8 @@ def executor(request):
2937
ids=[executor.name for executor in ALL_EXECUTORS],
3038
)
3139
def any_executor(request):
40+
if request.param.name == "processes" and "cupy" in nxp.__name__:
41+
pytest.skip(reason="CuPy is not supported with 'processes' executor")
3242
return request.param
3343

3444

@@ -384,6 +394,7 @@ def test_index_slice_unsupported_step(spec):
384394

385395

386396
@pytest.mark.parametrize("axis", [0, 1])
397+
@skip_if_cupy # ndindex with a cupy.ndarray
387398
def test_take(spec, axis):
388399
a = xp.asarray(
389400
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],

cubed/tests/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66

77
from cubed import config
8+
from cubed.backend_array_api import namespace as nxp
89
from cubed.runtime.create import create_executor
910
from cubed.runtime.types import Callback
1011
from cubed.storage.backend import open_backend_array
@@ -125,3 +126,14 @@ def execute_pipeline(pipeline, executor):
125126
dag = nx.MultiDiGraph()
126127
dag.add_node("node", pipeline=pipeline)
127128
executor.execute_dag(dag)
129+
130+
131+
# pytest is treated as optional here so that this module can still be imported
# without it; in that case ``skip_if_cupy`` is simply not defined.
try:
    import pytest
except ImportError:
    pass
else:
    # Marker that skips a test when the backend namespace is CuPy.
    skip_if_cupy = pytest.mark.skipif(
        "cupy" in nxp.__name__, reason="CuPy is not supported"
    )

docs/user-guide/gpus.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# GPU Support
2+
3+
Cubed has experimental support for using GPU-backed ndarrays. With zarr-python's
4+
[native GPU support], you can load data into GPU memory, perform some Cubed
5+
computation on the GPU, and write the result, while minimizing the number of host
6+
to device transfers.
7+
8+
```{note}
9+
Currently only NVIDIA GPUs and [CuPy] arrays are supported.
10+
```
11+
12+
Set the following environment variables to control whether host or device arrays
13+
are in Cubed and Zarr.
14+
15+
```shell
16+
# syntax may differ in your shell
17+
export CUBED_BACKEND_ARRAY_API_MODULE="array_api_compat.cupy"
18+
export ZARR_BUFFER="zarr.buffer.gpu.Buffer"
19+
export ZARR_NDBUFFER="zarr.buffer.gpu.NDBuffer"
20+
```
21+
22+
23+
[native GPU support]: https://zarr.readthedocs.io/en/stable/user-guide/gpu.html
24+
[CuPy]: https://cupy.dev/

docs/user-guide/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ reliability
1313
optimization
1414
scaling
1515
diagnostics
16+
gpus
1617
```

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ ignore_missing_imports = True
1919
ignore_missing_imports = True
2020
[mypy-coiled.*]
2121
ignore_missing_imports = True
22+
[mypy-cupy.*]
23+
ignore_missing_imports = True
2224
[mypy-dask.*]
2325
ignore_missing_imports = True
2426
[mypy-donfig.*]

0 commit comments

Comments
 (0)