|
10 | 10 | from zarr.storage._common import StorePath |
11 | 11 |
|
12 | 12 | if TYPE_CHECKING: |
13 | | - from collections.abc import Callable |
| 13 | + from collections.abc import Callable, Iterator |
14 | 14 | from pathlib import Path |
15 | 15 |
|
16 | 16 | from zarr.abc.codec import Codec |
@@ -1257,7 +1257,7 @@ def test_api_exports() -> None: |
1257 | 1257 | assert zarr.api.asynchronous.__all__ == zarr.api.synchronous.__all__ |
1258 | 1258 |
|
1259 | 1259 |
|
1260 | | -@gpu_test |
| 1260 | +@gpu_test # type: ignore[misc] |
1261 | 1261 | @pytest.mark.parametrize( |
1262 | 1262 | "store", |
1263 | 1263 | ["local", "memory", "zip"], |
@@ -1295,6 +1295,46 @@ def test_gpu_basic(store: Store, zarr_format: ZarrFormat | None, codec: str | Co |
1295 | 1295 | cp.testing.assert_array_equal(result, src[:10, :10]) |
1296 | 1296 |
|
1297 | 1297 |
|
| 1298 | +@gpu_test # type: ignore[misc] |
| 1299 | +@pytest.mark.parametrize("host_encode", [True, False]) |
| 1300 | +def test_gpu_codec_compatibility(host_encode: bool) -> None: |
| 1301 | + # Ensure that the we can decode CPU-encoded data with the GPU |
| 1302 | + # and GPU-encoded data with the CPU |
| 1303 | + import cupy as cp |
| 1304 | + |
| 1305 | + @contextlib.contextmanager |
| 1306 | + def gpu_context() -> Iterator[None]: |
| 1307 | + with zarr.config.enable_gpu(): |
| 1308 | + yield |
| 1309 | + |
| 1310 | + if host_encode: |
| 1311 | + write_ctx: contextlib.AbstractContextManager[None] = contextlib.nullcontext() |
| 1312 | + read_ctx: contextlib.AbstractContextManager[None] = gpu_context() |
| 1313 | + write_data = np.arange(16, dtype="int32").reshape(4, 4) |
| 1314 | + read_data = cp.array(write_data) |
| 1315 | + xp = cp |
| 1316 | + else: |
| 1317 | + write_ctx = gpu_context() |
| 1318 | + read_ctx = contextlib.nullcontext() |
| 1319 | + write_data = cp.arange(16, dtype="int32").reshape(4, 4) |
| 1320 | + read_data = write_data.get() |
| 1321 | + xp = np |
| 1322 | + |
| 1323 | + with write_ctx: |
| 1324 | + z = zarr.create_array( |
| 1325 | + store=zarr.storage.MemoryStore(), |
| 1326 | + shape=write_data.shape, |
| 1327 | + chunks=(4, 4), |
| 1328 | + dtype=write_data.dtype, |
| 1329 | + ) |
| 1330 | + z[:] = write_data |
| 1331 | + |
| 1332 | + with read_ctx: |
| 1333 | + result = z[:] |
| 1334 | + assert isinstance(result, type(read_data)) |
| 1335 | + xp.testing.assert_array_equal(result, read_data) |
| 1336 | + |
| 1337 | + |
1298 | 1338 | def test_v2_without_compressor() -> None: |
1299 | 1339 | # Make sure it's possible to set no compressor for v2 arrays |
1300 | 1340 | arr = zarr.create(store={}, shape=(1), dtype="uint8", zarr_format=2, compressor=None) |
@@ -1414,3 +1454,5 @@ def test_unimplemented_kwarg_warnings(kwarg_name: str) -> None: |
1414 | 1454 | kwargs = {kwarg_name: 1} |
1415 | 1455 | with pytest.warns(RuntimeWarning, match=".* is not yet implemented"): |
1416 | 1456 | zarr.create(shape=(1,), **kwargs) # type: ignore[arg-type] |
| 1457 | + |
| 1458 | + |
0 commit comments