2323 Store ,
2424 SuffixByteRequest ,
2525)
26- from zarr .core .buffer import Buffer , default_buffer_prototype
26+ from zarr .core .buffer import Buffer , cpu , default_buffer_prototype
2727from zarr .core .sync import _collect_aiterator , sync
2828from zarr .storage ._utils import _normalize_byte_range_index
2929from zarr .testing .utils import assert_bytes_equal
@@ -202,6 +202,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None:
202202 ):
203203 await reader .delete ("foo" )
204204
205+ @pytest .mark .parametrize (
206+ "prototype" ,
207+ [
208+ None , # Should use store's default buffer class
209+ default_buffer_prototype (), # BufferPrototype instance
210+ default_buffer_prototype ().buffer , # Raw Buffer class (cpu.Buffer)
211+ ],
212+ ids = ["prototype=None" , "prototype=BufferPrototype" , "prototype=Buffer" ],
213+ )
205214 @pytest .mark .parametrize ("key" , ["c/0" , "foo/c/0.0" , "foo/0/0" ])
206215 @pytest .mark .parametrize (
207216 ("data" , "byte_range" ),
@@ -213,13 +222,15 @@ async def test_with_read_only_store(self, open_kwargs: dict[str, Any]) -> None:
213222 (b"" , None ),
214223 ],
215224 )
216- async def test_get (self , store : S , key : str , data : bytes , byte_range : ByteRequest ) -> None :
225+ async def test_get (
226+ self , store : S , key : str , data : bytes , byte_range : ByteRequest , prototype : BufferLike | None
227+ ) -> None :
217228 """
218229 Ensure that data can be read from the store using the store.get method.
219230 """
220231 data_buf = self .buffer_cls .from_bytes (data )
221232 await self .set (store , key , data_buf )
222- observed = await store .get (key , prototype = default_buffer_prototype () , byte_range = byte_range )
233+ observed = await store .get (key , prototype = prototype , byte_range = byte_range )
223234 start , stop = _normalize_byte_range_index (data_buf , byte_range = byte_range )
224235 expected = data_buf [start :stop ]
225236 assert_bytes_equal (observed , expected )
@@ -244,32 +255,6 @@ async def test_get_raises(self, store: S) -> None:
244255 with pytest .raises ((ValueError , TypeError ), match = r"Unexpected byte_range, got.*" ):
245256 await store .get ("c/0" , prototype = default_buffer_prototype (), byte_range = (0 , 2 )) # type: ignore[arg-type]
246257
247- @pytest .mark .parametrize (
248- "prototype" ,
249- [
250- None , # Should use store's default buffer class
251- default_buffer_prototype (), # BufferPrototype instance
252- default_buffer_prototype ().buffer , # Raw Buffer class (cpu.Buffer)
253- ],
254- ids = ["prototype=None" , "prototype=BufferPrototype" , "prototype=Buffer" ],
255- )
256- async def test_get_with_buffer_like (self , store : S , prototype : BufferLike | None ) -> None :
257- """
258- Test that store.get() works with all BufferLike variants:
259- - None (uses store's default)
260- - BufferPrototype instance
261- - Raw Buffer class
262- """
263- data = b"\x01 \x02 \x03 \x04 "
264- key = "test_buffer_like"
265- data_buf = self .buffer_cls .from_bytes (data )
266- await self .set (store , key , data_buf )
267-
268- # Get with the parametrized prototype
269- observed = await store .get (key , prototype = prototype )
270- assert observed is not None
271- assert_bytes_equal (observed , data_buf )
272-
273258 async def test_get_many (self , store : S ) -> None :
274259 """
275260 Ensure that multiple keys can be retrieved at once with the _get_many method.
@@ -358,6 +343,15 @@ async def test_set_many(self, store: S) -> None:
358343 for k , v in store_dict .items ():
359344 assert (await self .get (store , k )).to_bytes () == v .to_bytes ()
360345
346+ @pytest .mark .parametrize (
347+ "prototype" ,
348+ [
349+ None , # Should use store's default buffer class
350+ default_buffer_prototype (), # BufferPrototype instance
351+ default_buffer_prototype ().buffer , # Raw Buffer class (cpu.Buffer)
352+ ],
353+ ids = ["prototype=None" , "prototype=BufferPrototype" , "prototype=Buffer" ],
354+ )
361355 @pytest .mark .parametrize (
362356 "key_ranges" ,
363357 [
@@ -372,65 +366,13 @@ async def test_set_many(self, store: S) -> None:
372366 ],
373367 )
374368 async def test_get_partial_values (
375- self , store : S , key_ranges : list [tuple [str , ByteRequest ]]
369+ self , store : S , key_ranges : list [tuple [str , ByteRequest ]], prototype : BufferLike | None
376370 ) -> None :
377371 # put all of the data
378372 for key , _ in key_ranges :
379373 await self .set (store , key , self .buffer_cls .from_bytes (bytes (key , encoding = "utf-8" )))
380374
381375 # read back just part of it
382- observed_maybe = await store .get_partial_values (
383- prototype = default_buffer_prototype (), key_ranges = key_ranges
384- )
385-
386- observed : list [Buffer ] = []
387- expected : list [Buffer ] = []
388-
389- for obs in observed_maybe :
390- assert obs is not None
391- observed .append (obs )
392-
393- for idx in range (len (observed )):
394- key , byte_range = key_ranges [idx ]
395- result = await store .get (
396- key , prototype = default_buffer_prototype (), byte_range = byte_range
397- )
398- assert result is not None
399- expected .append (result )
400-
401- assert all (
402- obs .to_bytes () == exp .to_bytes () for obs , exp in zip (observed , expected , strict = True )
403- )
404-
405- @pytest .mark .parametrize (
406- "prototype" ,
407- [
408- None , # Should use store's default buffer class
409- default_buffer_prototype (), # BufferPrototype instance
410- default_buffer_prototype ().buffer , # Raw Buffer class (cpu.Buffer)
411- ],
412- ids = ["prototype=None" , "prototype=BufferPrototype" , "prototype=Buffer" ],
413- )
414- async def test_get_partial_values_with_buffer_like (
415- self , store : S , prototype : BufferLike | None
416- ) -> None :
417- """
418- Test that store.get_partial_values() works with all BufferLike variants:
419- - None (uses store's default)
420- - BufferPrototype instance
421- - Raw Buffer class
422- """
423- key_ranges : list [tuple [str , ByteRequest | None ]] = [
424- ("c/0" , RangeByteRequest (0 , 2 )),
425- ("c/1" , None ),
426- ("c/2" , SuffixByteRequest (2 )),
427- ]
428-
429- # put all of the data
430- for key , _ in key_ranges :
431- await self .set (store , key , self .buffer_cls .from_bytes (bytes (key , encoding = "utf-8" )))
432-
433- # read back with the parametrized prototype
434376 observed_maybe = await store .get_partial_values (prototype = prototype , key_ranges = key_ranges )
435377
436378 observed : list [Buffer ] = []
@@ -442,7 +384,7 @@ async def test_get_partial_values_with_buffer_like(
442384
443385 for idx in range (len (observed )):
444386 key , byte_range = key_ranges [idx ]
445- result = await store .get (key , prototype = prototype , byte_range = byte_range )
387+ result = await store .get (key , prototype = cpu . Buffer , byte_range = byte_range )
446388 assert result is not None
447389 expected .append (result )
448390
0 commit comments