Skip to content

Commit 76f7560

Browse files
committed
added a matching test
1 parent dd69543 commit 76f7560

File tree

3 files changed

+33
-80
lines changed

3 files changed

+33
-80
lines changed

src/zarr/codecs/gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _convert_from_nvcomp_arrays(
9696
chunks_and_specs: Iterable[tuple[Buffer | None, ArraySpec]],
9797
) -> Iterable[Buffer | None]:
9898
return [
99-
spec.prototype.buffer.from_array_like(cp.array(a, dtype=np.dtype("b"), copy=False))
99+
spec.prototype.buffer.from_array_like(cp.array(a, dtype=np.dtype("B"), copy=False))
100100
if a
101101
else None
102102
for a, (_, spec) in zip(arrays, chunks_and_specs, strict=True)

tests/test_api.py

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from zarr.storage._common import StorePath
1111

1212
if TYPE_CHECKING:
13-
from collections.abc import Callable, Iterator
13+
from collections.abc import Callable
1414
from pathlib import Path
1515

1616
from zarr.abc.codec import Codec
@@ -1257,7 +1257,7 @@ def test_api_exports() -> None:
12571257
assert zarr.api.asynchronous.__all__ == zarr.api.synchronous.__all__
12581258

12591259

1260-
@gpu_test # type: ignore[misc]
1260+
@gpu_test # type: ignore[misc,unused-ignore]
12611261
@pytest.mark.parametrize(
12621262
"store",
12631263
["local", "memory", "zip"],
@@ -1295,46 +1295,6 @@ def test_gpu_basic(store: Store, zarr_format: ZarrFormat | None, codec: str | Co
12951295
cp.testing.assert_array_equal(result, src[:10, :10])
12961296

12971297

1298-
@gpu_test # type: ignore[misc]
1299-
@pytest.mark.parametrize("host_encode", [True, False])
1300-
def test_gpu_codec_compatibility(host_encode: bool) -> None:
1301-
# Ensure that the we can decode CPU-encoded data with the GPU
1302-
# and GPU-encoded data with the CPU
1303-
import cupy as cp
1304-
1305-
@contextlib.contextmanager
1306-
def gpu_context() -> Iterator[None]:
1307-
with zarr.config.enable_gpu():
1308-
yield
1309-
1310-
if host_encode:
1311-
write_ctx: contextlib.AbstractContextManager[None] = contextlib.nullcontext()
1312-
read_ctx: contextlib.AbstractContextManager[None] = gpu_context()
1313-
write_data = np.arange(16, dtype="int32").reshape(4, 4)
1314-
read_data = cp.array(write_data)
1315-
xp = cp
1316-
else:
1317-
write_ctx = gpu_context()
1318-
read_ctx = contextlib.nullcontext()
1319-
write_data = cp.arange(16, dtype="int32").reshape(4, 4)
1320-
read_data = write_data.get()
1321-
xp = np
1322-
1323-
with write_ctx:
1324-
z = zarr.create_array(
1325-
store=zarr.storage.MemoryStore(),
1326-
shape=write_data.shape,
1327-
chunks=(4, 4),
1328-
dtype=write_data.dtype,
1329-
)
1330-
z[:] = write_data
1331-
1332-
with read_ctx:
1333-
result = z[:]
1334-
assert isinstance(result, type(read_data))
1335-
xp.testing.assert_array_equal(result, read_data)
1336-
1337-
13381298
def test_v2_without_compressor() -> None:
13391299
# Make sure it's possible to set no compressor for v2 arrays
13401300
arr = zarr.create(store={}, shape=(1), dtype="uint8", zarr_format=2, compressor=None)
@@ -1454,5 +1414,3 @@ def test_unimplemented_kwarg_warnings(kwarg_name: str) -> None:
14541414
kwargs = {kwarg_name: 1}
14551415
with pytest.warns(RuntimeWarning, match=".* is not yet implemented"):
14561416
zarr.create(shape=(1,), **kwargs) # type: ignore[arg-type]
1457-
1458-

tests/test_codecs/test_nvcomp.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import typing
3+
from collections.abc import Iterator
34

45
import numpy as np
56
import pytest
@@ -8,7 +9,6 @@
89
from zarr.abc.store import Store
910
from zarr.buffer.gpu import buffer_prototype
1011
from zarr.codecs import NvcompZstdCodec
11-
from zarr.codecs.zstd import ZstdCodec
1212
from zarr.core.array_spec import ArrayConfig, ArraySpec
1313
from zarr.storage import StorePath
1414
from zarr.testing.utils import gpu_test
@@ -65,54 +65,49 @@ def test_nvcomp_zstd(store: Store, checksum: bool, selection: tuple[slice, slice
6565

6666

6767
@gpu_test # type: ignore[misc,unused-ignore]
68-
@pytest.mark.parametrize("host_encode", [True, False], ids=["host_encode", "device_encode"])
69-
def test_nvcomp_zstd_matches(memory_store: zarr.storage.MemoryStore, host_encode: bool) -> None:
70-
# Test to ensure that
68+
@pytest.mark.parametrize("host_encode", [True, False])
69+
def test_gpu_codec_compatibility(host_encode: bool) -> None:
70+
# Ensure that the we can decode CPU-encoded data with the GPU
71+
# and GPU-encoded data with the CPU
7172
import cupy as cp
7273

7374
@contextlib.contextmanager
74-
def enable_gpu():
75+
def gpu_context() -> Iterator[None]:
7576
with zarr.config.enable_gpu():
7677
yield
7778

7879
if host_encode:
79-
write_ctx = contextlib.nullcontext()
80-
read_ctx = enable_gpu()
81-
data = np.arange(0, 256, dtype="uint16").reshape((16, 16))
80+
# CPU encode, GPU decode
81+
write_ctx: contextlib.AbstractContextManager[None] = contextlib.nullcontext()
82+
read_ctx: contextlib.AbstractContextManager[None] = gpu_context()
83+
write_data = np.arange(16, dtype="int32").reshape(4, 4)
84+
read_data = cp.array(write_data)
85+
xp = cp
8286
else:
83-
write_ctx = zarr.config.enable_gpu()
84-
read_ctx = enable_gpu()
85-
data = cp.arange(0, 256, dtype="uint16").reshape((16, 16))
87+
# GPU encode, CPU decode
88+
write_ctx = gpu_context()
89+
read_ctx = contextlib.nullcontext()
90+
write_data = cp.arange(16, dtype="int32").reshape(4, 4)
91+
read_data = write_data.get()
92+
xp = np
93+
94+
store = zarr.storage.MemoryStore()
8695

8796
with write_ctx:
88-
a = zarr.create_array(
89-
store=memory_store,
90-
name="data",
91-
shape=data.shape,
97+
z = zarr.create_array(
98+
store=store,
99+
shape=write_data.shape,
92100
chunks=(4, 4),
93-
dtype=data.dtype,
94-
fill_value=0,
101+
dtype=write_data.dtype,
95102
)
96-
a[:, :] = data
97-
assert a.metadata.zarr_format == 3
98-
assert a.metadata.codecs[-1].to_dict()["name"] == "zstd"
103+
z[:] = write_data
99104

100105
with read_ctx:
101-
b = zarr.open_array(a.store_path, mode="r")
102-
assert b.metadata.zarr_format == 3
103-
104-
zstd_codec = b.metadata.codecs[-1]
105-
assert zstd_codec.to_dict()["name"] == "zstd"
106-
107-
if host_encode:
108-
assert isinstance(zstd_codec, NvcompZstdCodec)
109-
else:
110-
assert isinstance(zstd_codec, ZstdCodec)
111-
112-
if host_encode:
113-
cp.testing.assert_array_equal(data, b[:, :])
114-
else:
115-
np.testing.assert_array_equal(data.get(), b[:, :])
106+
# We need to reopen z, because `z.codec_pipeline` is set at creation
107+
z = zarr.open_array(store=store, mode="r")
108+
result = z[:]
109+
assert isinstance(result, type(read_data))
110+
xp.testing.assert_array_equal(result, read_data)
116111

117112

118113
@gpu_test # type: ignore[misc,unused-ignore]

0 commit comments

Comments
 (0)