Skip to content

Commit dd69543

Browse files
committed
added failing compatibility test
1 parent 398b4d1 commit dd69543

File tree

1 file changed

+44
-2
lines changed

1 file changed

+44
-2
lines changed

tests/test_api.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from zarr.storage._common import StorePath
1111

1212
if TYPE_CHECKING:
13-
from collections.abc import Callable
13+
from collections.abc import Callable, Iterator
1414
from pathlib import Path
1515

1616
from zarr.abc.codec import Codec
@@ -1257,7 +1257,7 @@ def test_api_exports() -> None:
12571257
assert zarr.api.asynchronous.__all__ == zarr.api.synchronous.__all__
12581258

12591259

1260-
@gpu_test
1260+
@gpu_test # type: ignore[misc]
12611261
@pytest.mark.parametrize(
12621262
"store",
12631263
["local", "memory", "zip"],
@@ -1295,6 +1295,46 @@ def test_gpu_basic(store: Store, zarr_format: ZarrFormat | None, codec: str | Co
12951295
cp.testing.assert_array_equal(result, src[:10, :10])
12961296

12971297

1298+
@gpu_test # type: ignore[misc]
1299+
@pytest.mark.parametrize("host_encode", [True, False])
1300+
def test_gpu_codec_compatibility(host_encode: bool) -> None:
1301+
# Ensure that the we can decode CPU-encoded data with the GPU
1302+
# and GPU-encoded data with the CPU
1303+
import cupy as cp
1304+
1305+
@contextlib.contextmanager
1306+
def gpu_context() -> Iterator[None]:
1307+
with zarr.config.enable_gpu():
1308+
yield
1309+
1310+
if host_encode:
1311+
write_ctx: contextlib.AbstractContextManager[None] = contextlib.nullcontext()
1312+
read_ctx: contextlib.AbstractContextManager[None] = gpu_context()
1313+
write_data = np.arange(16, dtype="int32").reshape(4, 4)
1314+
read_data = cp.array(write_data)
1315+
xp = cp
1316+
else:
1317+
write_ctx = gpu_context()
1318+
read_ctx = contextlib.nullcontext()
1319+
write_data = cp.arange(16, dtype="int32").reshape(4, 4)
1320+
read_data = write_data.get()
1321+
xp = np
1322+
1323+
with write_ctx:
1324+
z = zarr.create_array(
1325+
store=zarr.storage.MemoryStore(),
1326+
shape=write_data.shape,
1327+
chunks=(4, 4),
1328+
dtype=write_data.dtype,
1329+
)
1330+
z[:] = write_data
1331+
1332+
with read_ctx:
1333+
result = z[:]
1334+
assert isinstance(result, type(read_data))
1335+
xp.testing.assert_array_equal(result, read_data)
1336+
1337+
12981338
def test_v2_without_compressor() -> None:
12991339
# Make sure it's possible to set no compressor for v2 arrays
13001340
arr = zarr.create(store={}, shape=(1), dtype="uint8", zarr_format=2, compressor=None)
@@ -1414,3 +1454,5 @@ def test_unimplemented_kwarg_warnings(kwarg_name: str) -> None:
14141454
kwargs = {kwarg_name: 1}
14151455
with pytest.warns(RuntimeWarning, match=".* is not yet implemented"):
14161456
zarr.create(shape=(1,), **kwargs) # type: ignore[arg-type]
1457+
1458+

0 commit comments

Comments
 (0)