|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import json |
3 | 4 | import re |
4 | 5 | from typing import TYPE_CHECKING, Any |
5 | 6 |
|
|
9 | 10 |
|
10 | 11 | import zarr |
11 | 12 | from zarr.core.buffer import Buffer, cpu, gpu |
| 13 | +from zarr.core.sync import sync |
12 | 14 | from zarr.errors import ZarrUserWarning |
13 | 15 | from zarr.storage import GpuMemoryStore, MemoryStore |
14 | 16 | from zarr.testing.store import StoreTests |
15 | 17 | from zarr.testing.utils import gpu_test |
16 | 18 |
|
17 | 19 | if TYPE_CHECKING: |
| 20 | + from zarr.core.buffer import BufferPrototype |
18 | 21 | from zarr.core.common import ZarrFormat |
19 | 22 |
|
20 | 23 |
|
@@ -76,75 +79,53 @@ async def test_deterministic_size( |
76 | 79 | np.testing.assert_array_equal(a[:3], 1) |
77 | 80 | np.testing.assert_array_equal(a[3:], 0) |
78 | 81 |
|
79 | | - async def test_get_bytes_with_prototype_none(self, store: MemoryStore) -> None: |
80 | | - """Test that get_bytes_async works with prototype=None.""" |
81 | | - from zarr.core.buffer.core import default_buffer_prototype |
82 | | - |
| 82 | + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) |
| 83 | + async def test_get_bytes_with_prototype_none( |
| 84 | + self, store: MemoryStore, buffer_cls: None | BufferPrototype |
| 85 | + ) -> None: |
| 86 | + """Test that get_bytes works with prototype=None.""" |
83 | 87 | data = b"hello world" |
84 | 88 | key = "test_key" |
85 | 89 | await self.set(store, key, self.buffer_cls.from_bytes(data)) |
86 | 90 |
|
87 | | - # Test with None (default) |
88 | | - result_none = await store.get_bytes(key) |
89 | | - assert result_none == data |
90 | | - |
91 | | - # Test with explicit prototype |
92 | | - result_proto = await store.get_bytes(key, prototype=default_buffer_prototype()) |
93 | | - assert result_proto == data |
94 | | - |
95 | | - def test_get_bytes_sync_with_prototype_none(self, store: MemoryStore) -> None: |
96 | | - """Test that get_bytes works with prototype=None.""" |
97 | | - from zarr.core.buffer.core import default_buffer_prototype |
98 | | - from zarr.core.sync import sync |
| 91 | + result = await store.get_bytes(key, prototype=buffer_cls) |
| 92 | + assert result == data |
99 | 93 |
|
| 94 | + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) |
| 95 | + def test_get_bytes_sync_with_prototype_none( |
| 96 | + self, store: MemoryStore, buffer_cls: None | BufferPrototype |
| 97 | + ) -> None: |
| 98 | + """Test that get_bytes_sync works with prototype=None.""" |
100 | 99 | data = b"hello world" |
101 | 100 | key = "test_key" |
102 | 101 | sync(self.set(store, key, self.buffer_cls.from_bytes(data))) |
103 | 102 |
|
104 | | - # Test with None (default) |
105 | | - result_none = store.get_bytes_sync(key) |
106 | | - assert result_none == data |
| 103 | + result = store.get_bytes_sync(key, prototype=buffer_cls) |
| 104 | + assert result == data |
107 | 105 |
|
108 | | - # Test with explicit prototype |
109 | | - result_proto = store.get_bytes_sync(key, prototype=default_buffer_prototype()) |
110 | | - assert result_proto == data |
111 | | - |
112 | | - async def test_get_json_with_prototype_none(self, store: MemoryStore) -> None: |
| 106 | + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) |
| 107 | + async def test_get_json_with_prototype_none( |
| 108 | + self, store: MemoryStore, buffer_cls: None | BufferPrototype |
| 109 | + ) -> None: |
113 | 110 | """Test that get_json works with prototype=None.""" |
114 | | - import json |
115 | | - |
116 | | - from zarr.core.buffer.core import default_buffer_prototype |
117 | | - |
118 | 111 | data = {"foo": "bar", "number": 42} |
119 | 112 | key = "test.json" |
120 | 113 | await self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode())) |
121 | 114 |
|
122 | | - # Test with None (default) |
123 | | - result_none = await store.get_json(key) |
124 | | - assert result_none == data |
125 | | - |
126 | | - # Test with explicit prototype |
127 | | - result_proto = await store.get_json(key, prototype=default_buffer_prototype()) |
128 | | - assert result_proto == data |
129 | | - |
130 | | - def test_get_json_sync_with_prototype_none(self, store: MemoryStore) -> None: |
131 | | - """Test that get_json works with prototype=None.""" |
132 | | - import json |
133 | | - |
134 | | - from zarr.core.buffer.core import default_buffer_prototype |
135 | | - from zarr.core.sync import sync |
| 115 | + result = await store.get_json(key, prototype=buffer_cls) |
| 116 | + assert result == data |
136 | 117 |
|
| 118 | + @pytest.mark.parametrize("buffer_cls", [None, cpu.buffer_prototype]) |
| 119 | + def test_get_json_sync_with_prototype_none( |
| 120 | + self, store: MemoryStore, buffer_cls: None | BufferPrototype |
| 121 | + ) -> None: |
| 122 | + """Test that get_json_sync works with prototype=None.""" |
137 | 123 | data = {"foo": "bar", "number": 42} |
138 | 124 | key = "test.json" |
139 | 125 | sync(self.set(store, key, self.buffer_cls.from_bytes(json.dumps(data).encode()))) |
140 | 126 |
|
141 | | - # Test with None (default) |
142 | | - result_none = store.get_json_sync(key) |
143 | | - assert result_none == data |
144 | | - |
145 | | - # Test with explicit prototype |
146 | | - result_proto = store.get_json_sync(key, prototype=default_buffer_prototype()) |
147 | | - assert result_proto == data |
| 127 | + result = store.get_json_sync(key, prototype=buffer_cls) |
| 128 | + assert result == data |
148 | 129 |
|
149 | 130 |
|
150 | 131 | # TODO: fix this warning |
|
0 commit comments