|
1 | 1 | import contextlib |
2 | 2 | import typing |
| 3 | +from collections.abc import Iterator |
3 | 4 |
|
4 | 5 | import numpy as np |
5 | 6 | import pytest |
|
8 | 9 | from zarr.abc.store import Store |
9 | 10 | from zarr.buffer.gpu import buffer_prototype |
10 | 11 | from zarr.codecs import NvcompZstdCodec |
11 | | -from zarr.codecs.zstd import ZstdCodec |
12 | 12 | from zarr.core.array_spec import ArrayConfig, ArraySpec |
13 | 13 | from zarr.storage import StorePath |
14 | 14 | from zarr.testing.utils import gpu_test |
@@ -65,54 +65,49 @@ def test_nvcomp_zstd(store: Store, checksum: bool, selection: tuple[slice, slice |
65 | 65 |
|
66 | 66 |
|
67 | 67 | @gpu_test # type: ignore[misc,unused-ignore] |
68 | | -@pytest.mark.parametrize("host_encode", [True, False], ids=["host_encode", "device_encode"]) |
69 | | -def test_nvcomp_zstd_matches(memory_store: zarr.storage.MemoryStore, host_encode: bool) -> None: |
70 | | - # Test to ensure that |
| 68 | +@pytest.mark.parametrize("host_encode", [True, False]) |
| 69 | +def test_gpu_codec_compatibility(host_encode: bool) -> None: |
| 70 | + # Ensure that the we can decode CPU-encoded data with the GPU |
| 71 | + # and GPU-encoded data with the CPU |
71 | 72 | import cupy as cp |
72 | 73 |
|
73 | 74 | @contextlib.contextmanager |
74 | | - def enable_gpu(): |
| 75 | + def gpu_context() -> Iterator[None]: |
75 | 76 | with zarr.config.enable_gpu(): |
76 | 77 | yield |
77 | 78 |
|
78 | 79 | if host_encode: |
79 | | - write_ctx = contextlib.nullcontext() |
80 | | - read_ctx = enable_gpu() |
81 | | - data = np.arange(0, 256, dtype="uint16").reshape((16, 16)) |
| 80 | + # CPU encode, GPU decode |
| 81 | + write_ctx: contextlib.AbstractContextManager[None] = contextlib.nullcontext() |
| 82 | + read_ctx: contextlib.AbstractContextManager[None] = gpu_context() |
| 83 | + write_data = np.arange(16, dtype="int32").reshape(4, 4) |
| 84 | + read_data = cp.array(write_data) |
| 85 | + xp = cp |
82 | 86 | else: |
83 | | - write_ctx = zarr.config.enable_gpu() |
84 | | - read_ctx = enable_gpu() |
85 | | - data = cp.arange(0, 256, dtype="uint16").reshape((16, 16)) |
| 87 | + # GPU encode, CPU decode |
| 88 | + write_ctx = gpu_context() |
| 89 | + read_ctx = contextlib.nullcontext() |
| 90 | + write_data = cp.arange(16, dtype="int32").reshape(4, 4) |
| 91 | + read_data = write_data.get() |
| 92 | + xp = np |
| 93 | + |
| 94 | + store = zarr.storage.MemoryStore() |
86 | 95 |
|
87 | 96 | with write_ctx: |
88 | | - a = zarr.create_array( |
89 | | - store=memory_store, |
90 | | - name="data", |
91 | | - shape=data.shape, |
| 97 | + z = zarr.create_array( |
| 98 | + store=store, |
| 99 | + shape=write_data.shape, |
92 | 100 | chunks=(4, 4), |
93 | | - dtype=data.dtype, |
94 | | - fill_value=0, |
| 101 | + dtype=write_data.dtype, |
95 | 102 | ) |
96 | | - a[:, :] = data |
97 | | - assert a.metadata.zarr_format == 3 |
98 | | - assert a.metadata.codecs[-1].to_dict()["name"] == "zstd" |
| 103 | + z[:] = write_data |
99 | 104 |
|
100 | 105 | with read_ctx: |
101 | | - b = zarr.open_array(a.store_path, mode="r") |
102 | | - assert b.metadata.zarr_format == 3 |
103 | | - |
104 | | - zstd_codec = b.metadata.codecs[-1] |
105 | | - assert zstd_codec.to_dict()["name"] == "zstd" |
106 | | - |
107 | | - if host_encode: |
108 | | - assert isinstance(zstd_codec, NvcompZstdCodec) |
109 | | - else: |
110 | | - assert isinstance(zstd_codec, ZstdCodec) |
111 | | - |
112 | | - if host_encode: |
113 | | - cp.testing.assert_array_equal(data, b[:, :]) |
114 | | - else: |
115 | | - np.testing.assert_array_equal(data.get(), b[:, :]) |
| 106 | + # We need to reopen z, because `z.codec_pipeline` is set at creation |
| 107 | + z = zarr.open_array(store=store, mode="r") |
| 108 | + result = z[:] |
| 109 | + assert isinstance(result, type(read_data)) |
| 110 | + xp.testing.assert_array_equal(result, read_data) |
116 | 111 |
|
117 | 112 |
|
118 | 113 | @gpu_test # type: ignore[misc,unused-ignore] |
|
0 commit comments