Skip to content

Commit b110768

Browse files
committed
add BufferLike as buffer parameter for store methods that allocate memory
1 parent 0a97eb4 commit b110768

File tree

12 files changed

+326
-115
lines changed

12 files changed

+326
-115
lines changed

src/zarr/abc/store.py

Lines changed: 69 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,17 @@
77
from itertools import starmap
88
from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable
99

10+
from zarr.core.buffer import Buffer, BufferPrototype
1011
from zarr.core.sync import sync
1112

1213
if TYPE_CHECKING:
1314
from collections.abc import AsyncGenerator, AsyncIterator, Iterable
1415
from types import TracebackType
1516
from typing import Any, Self, TypeAlias
1617

17-
from zarr.core.buffer import Buffer, BufferPrototype
18+
__all__ = ["BufferLike", "ByteGetter", "ByteSetter", "Store", "set_or_delete"]
1819

19-
__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]
20+
BufferLike = type[Buffer] | BufferPrototype
2021

2122

2223
@dataclass
@@ -183,20 +184,31 @@ def __eq__(self, value: object) -> bool:
183184
"""Equality comparison."""
184185
...
185186

187+
@abstractmethod
188+
def _get_default_buffer_class(self) -> type[Buffer]:
189+
"""
190+
Get the default buffer class for this store.
191+
"""
192+
...
193+
186194
@abstractmethod
187195
async def get(
188196
self,
189197
key: str,
190-
prototype: BufferPrototype,
198+
prototype: BufferLike | None = None,
191199
byte_range: ByteRequest | None = None,
192200
) -> Buffer | None:
193201
"""Retrieve the value associated with a given key.
194202
195203
Parameters
196204
----------
197205
key : str
198-
prototype : BufferPrototype
199-
The prototype of the output buffer. Stores may support a default buffer prototype.
206+
prototype : BufferLike | None, optional
207+
The prototype of the output buffer.
208+
Can be either a Buffer class or an instance of `BufferPrototype`, in which case the
209+
`buffer` attribute will be used.
210+
If `None`, the default buffer class for this store will be retrieved via the
211+
``_get_default_buffer_class`` method.
200212
byte_range : ByteRequest, optional
201213
ByteRequest may be one of the following. If not provided, all data associated with the key is retrieved.
202214
- RangeByteRequest(int, int): Request a specific range of bytes in the form (start, end). The end is exclusive. If the given range is zero-length or starts after the end of the object, an error will be returned. Additionally, if the range ends after the end of the object, the entire remainder of the object will be returned. Otherwise, the exact requested range will be returned.
@@ -210,7 +222,11 @@ async def get(
210222
...
211223

212224
async def get_bytes(
213-
self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None
225+
self,
226+
key: str,
227+
*,
228+
prototype: BufferLike | None = None,
229+
byte_range: ByteRequest | None = None,
214230
) -> bytes:
215231
"""
216232
Retrieve raw bytes from the store asynchronously.
@@ -222,8 +238,12 @@ async def get_bytes(
222238
----------
223239
key : str
224240
The key identifying the data to retrieve.
225-
prototype : BufferPrototype
226-
The buffer prototype to use for reading the data.
241+
prototype : BufferLike | None, optional
242+
The prototype of the output buffer.
243+
Can be either a Buffer class or an instance of `BufferPrototype`, in which case the
244+
`buffer` attribute will be used.
245+
If `None`, the default buffer class for this store will be retrieved via the
246+
``_get_default_buffer_class`` method.
227247
byte_range : ByteRequest, optional
228248
If specified, only retrieve a portion of the stored data.
229249
Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``.
@@ -258,7 +278,11 @@ async def get_bytes(
258278
return buffer.to_bytes()
259279

260280
def get_bytes_sync(
261-
self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None
281+
self,
282+
key: str = "",
283+
*,
284+
prototype: BufferLike | None = None,
285+
byte_range: ByteRequest | None = None,
262286
) -> bytes:
263287
"""
264288
Retrieve raw bytes from the store synchronously.
@@ -271,8 +295,12 @@ def get_bytes_sync(
271295
----------
272296
key : str, optional
273297
The key identifying the data to retrieve. Defaults to an empty string.
274-
prototype : BufferPrototype
275-
The buffer prototype to use for reading the data.
298+
prototype : BufferLike | None, optional
299+
The prototype of the output buffer.
300+
Can be either a Buffer class or an instance of `BufferPrototype`, in which case the
301+
`buffer` attribute will be used.
302+
If `None`, the default buffer class for this store will be retrieved via the
303+
``_get_default_buffer_class`` method.
276304
byte_range : ByteRequest, optional
277305
If specified, only retrieve a portion of the stored data.
278306
Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``.
@@ -309,7 +337,11 @@ def get_bytes_sync(
309337
return sync(self.get_bytes(key, prototype=prototype, byte_range=byte_range))
310338

311339
async def get_json(
312-
self, key: str, *, prototype: BufferPrototype, byte_range: ByteRequest | None = None
340+
self,
341+
key: str,
342+
*,
343+
prototype: BufferLike | None = None,
344+
byte_range: ByteRequest | None = None,
313345
) -> Any:
314346
"""
315347
Retrieve and parse JSON data from the store asynchronously.
@@ -321,8 +353,12 @@ async def get_json(
321353
----------
322354
key : str
323355
The key identifying the JSON data to retrieve.
324-
prototype : BufferPrototype
325-
The buffer prototype to use for reading the data.
356+
prototype : BufferLike | None, optional
357+
The prototype of the output buffer.
358+
Can be either a Buffer class or an instance of `BufferPrototype`, in which case the
359+
`buffer` attribute will be used.
360+
If `None`, the default buffer class for this store will be retrieved via the
361+
``_get_default_buffer_class`` method.
326362
byte_range : ByteRequest, optional
327363
If specified, only retrieve a portion of the stored data.
328364
Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``.
@@ -359,7 +395,11 @@ async def get_json(
359395
return json.loads(await self.get_bytes(key, prototype=prototype, byte_range=byte_range))
360396

361397
def get_json_sync(
362-
self, key: str = "", *, prototype: BufferPrototype, byte_range: ByteRequest | None = None
398+
self,
399+
key: str = "",
400+
*,
401+
prototype: BufferLike | None = None,
402+
byte_range: ByteRequest | None = None,
363403
) -> Any:
364404
"""
365405
Retrieve and parse JSON data from the store synchronously.
@@ -372,8 +412,12 @@ def get_json_sync(
372412
----------
373413
key : str, optional
374414
The key identifying the JSON data to retrieve. Defaults to an empty string.
375-
prototype : BufferPrototype
376-
The buffer prototype to use for reading the data.
415+
prototype : BufferLike | None, optional
416+
The prototype of the output buffer.
417+
Can be either a Buffer class or an instance of `BufferPrototype`, in which case the
418+
`buffer` attribute will be used.
419+
If `None`, the default buffer class for this store will be retrieved via the
420+
``_get_default_buffer_class`` method.
377421
byte_range : ByteRequest, optional
378422
If specified, only retrieve a portion of the stored data.
379423
Can be a ``RangeByteRequest``, ``OffsetByteRequest``, or ``SuffixByteRequest``.
@@ -417,15 +461,19 @@ def get_json_sync(
417461
@abstractmethod
418462
async def get_partial_values(
419463
self,
420-
prototype: BufferPrototype,
464+
prototype: BufferLike | None,
421465
key_ranges: Iterable[tuple[str, ByteRequest | None]],
422466
) -> list[Buffer | None]:
423467
"""Retrieve possibly partial values from given key_ranges.
424468
425469
Parameters
426470
----------
427-
prototype : BufferPrototype
428-
The prototype of the output buffer. Stores may support a default buffer prototype.
471+
prototype : BufferLike | None
472+
The prototype of the output buffer.
473+
Can be either a Buffer class or an instance of `BufferPrototype`, in which case the
474+
`buffer` attribute will be used.
475+
If `None`, the default buffer class for this store will be retrieved via the
476+
``_get_default_buffer_class`` method.
429477
key_ranges : Iterable[tuple[str, ByteRequest | None]]
430478
Ordered set of key, range pairs, a key may occur multiple times with different ranges
431479
@@ -597,7 +645,7 @@ def close(self) -> None:
597645
self._is_open = False
598646

599647
async def _get_many(
600-
self, requests: Iterable[tuple[str, BufferPrototype, ByteRequest | None]]
648+
self, requests: Iterable[tuple[str, BufferLike | None, ByteRequest | None]]
601649
) -> AsyncGenerator[tuple[str, Buffer | None], None]:
602650
"""
603651
Retrieve a collection of objects from storage. In general this method does not guarantee

src/zarr/experimental/cache_store.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
from collections import OrderedDict
77
from typing import TYPE_CHECKING, Any, Literal
88

9-
from zarr.abc.store import ByteRequest, Store
9+
from zarr.abc.store import BufferLike, ByteRequest, Store
1010
from zarr.storage._wrapper import WrapperStore
1111

1212
logger = logging.getLogger(__name__)
1313

1414
if TYPE_CHECKING:
15-
from zarr.core.buffer.core import Buffer, BufferPrototype
15+
from zarr.core.buffer.core import Buffer
1616

1717

1818
class CacheStore(WrapperStore[Store]):
@@ -218,7 +218,7 @@ def _remove_from_tracking(self, key: str) -> None:
218218
self._key_sizes.pop(key, None)
219219

220220
async def _get_try_cache(
221-
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
221+
self, key: str, prototype: BufferLike | None, byte_range: ByteRequest | None = None
222222
) -> Buffer | None:
223223
"""Try to get data from cache first, falling back to source store."""
224224
maybe_cached_result = await self._cache.get(key, prototype, byte_range)
@@ -246,7 +246,7 @@ async def _get_try_cache(
246246
return maybe_fresh_result
247247

248248
async def _get_no_cache(
249-
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
249+
self, key: str, prototype: BufferLike | None, byte_range: ByteRequest | None = None
250250
) -> Buffer | None:
251251
"""Get data directly from source store and update cache."""
252252
self._misses += 1
@@ -265,7 +265,7 @@ async def _get_no_cache(
265265
async def get(
266266
self,
267267
key: str,
268-
prototype: BufferPrototype,
268+
prototype: BufferLike | None = None,
269269
byte_range: ByteRequest | None = None,
270270
) -> Buffer | None:
271271
"""
@@ -275,8 +275,12 @@ async def get(
275275
----------
276276
key : str
277277
The key to retrieve
278-
prototype : BufferPrototype
279-
Buffer prototype for creating the result buffer
278+
prototype : BufferLike | None, optional
279+
The prototype of the output buffer.
280+
Can be either a Buffer class or an instance of `BufferPrototype`, in which case the
281+
`buffer` attribute will be used.
282+
If `None`, the default buffer class for this store will be retrieved via the
283+
``_get_default_buffer_class`` method.
280284
byte_range : ByteRequest, optional
281285
Byte range to retrieve
282286

src/zarr/storage/_common.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import importlib.util
44
import json
55
from pathlib import Path
6-
from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias
6+
from typing import Any, Literal, Self, TypeAlias
77

8-
from zarr.abc.store import ByteRequest, Store
9-
from zarr.core.buffer import Buffer, default_buffer_prototype
8+
from zarr.abc.store import BufferLike, ByteRequest, Store
9+
from zarr.core.buffer import Buffer
1010
from zarr.core.common import (
1111
ANY_ACCESS_MODE,
1212
ZARR_JSON,
@@ -26,9 +26,6 @@
2626
else:
2727
FSMap = None
2828

29-
if TYPE_CHECKING:
30-
from zarr.core.buffer import BufferPrototype
31-
3229

3330
def _dereference_path(root: str, path: str) -> str:
3431
if not isinstance(root, str):
@@ -145,16 +142,20 @@ async def open(cls, store: Store, path: str, mode: AccessModeLiteral | None = No
145142

146143
async def get(
147144
self,
148-
prototype: BufferPrototype | None = None,
145+
prototype: BufferLike | None = None,
149146
byte_range: ByteRequest | None = None,
150147
) -> Buffer | None:
151148
"""
152149
Read bytes from the store.
153150
154151
Parameters
155152
----------
156-
prototype : BufferPrototype, optional
157-
The buffer prototype to use when reading the bytes.
153+
prototype : BufferLike | None, optional
154+
The prototype of the output buffer.
155+
Can be either a Buffer class or an instance of `BufferPrototype`, in which case the
156+
`buffer` attribute will be used.
157+
If `None`, the default buffer class for this store will be retrieved via the
158+
store's ``_get_default_buffer_class`` method.
158159
byte_range : ByteRequest, optional
159160
The range of bytes to read.
160161
@@ -164,7 +165,7 @@ async def get(
164165
The read bytes, or None if the key does not exist.
165166
"""
166167
if prototype is None:
167-
prototype = default_buffer_prototype()
168+
prototype = self.store._get_default_buffer_class()
168169
return await self.store.get(self.path, prototype=prototype, byte_range=byte_range)
169170

170171
async def set(self, value: Buffer) -> None:

0 commit comments

Comments
 (0)