Skip to content

Commit 6904239

Browse files
committed
remove as many default_buffer_prototype() invocations as possible
1 parent 44c5882 commit 6904239

33 files changed

+220
-207
lines changed

src/zarr/abc/store.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,18 @@
88
from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable
99

1010
from zarr.core.buffer import Buffer, BufferPrototype
11-
from zarr.core.buffer.core import default_buffer_prototype
1211
from zarr.core.sync import sync
12+
from zarr.registry import get_buffer_class
1313

1414
if TYPE_CHECKING:
1515
from collections.abc import AsyncGenerator, AsyncIterator, Iterable
1616
from types import TracebackType
1717
from typing import Any, Self, TypeAlias
1818

19-
__all__ = ["BufferLike", "ByteGetter", "ByteSetter", "Store", "set_or_delete"]
19+
__all__ = ["BufferClassLike", "ByteGetter", "ByteSetter", "Store", "set_or_delete"]
2020

21-
BufferLike = type[Buffer] | BufferPrototype
21+
BufferClassLike = type[Buffer] | BufferPrototype
22+
"""An object that is or contains a Buffer class"""
2223

2324

2425
@dataclass
@@ -189,13 +190,13 @@ def _get_default_buffer_class(self) -> type[Buffer]:
189190
"""
190191
Get the default buffer class.
191192
"""
192-
return default_buffer_prototype().buffer
193+
return get_buffer_class()
193194

194195
@abstractmethod
195196
async def get(
196197
self,
197198
key: str,
198-
prototype: BufferLike | None = None,
199+
prototype: BufferClassLike | None = None,
199200
byte_range: ByteRequest | None = None,
200201
) -> Buffer | None:
201202
"""Retrieve the value associated with a given key.
@@ -225,7 +226,7 @@ async def _get_bytes(
225226
self,
226227
key: str,
227228
*,
228-
prototype: BufferLike | None = None,
229+
prototype: BufferClassLike | None = None,
229230
byte_range: ByteRequest | None = None,
230231
) -> bytes:
231232
"""
@@ -268,7 +269,7 @@ async def _get_bytes(
268269
--------
269270
>>> store = await MemoryStore.open()
270271
>>> await store.set("data", Buffer.from_bytes(b"hello world"))
271-
>>> data = await store._get_bytes("data", prototype=default_buffer_prototype())
272+
>>> data = await store._get_bytes("data")
272273
>>> print(data)
273274
b'hello world'
274275
"""
@@ -281,7 +282,7 @@ def _get_bytes_sync(
281282
self,
282283
key: str = "",
283284
*,
284-
prototype: BufferLike | None = None,
285+
prototype: BufferClassLike | None = None,
285286
byte_range: ByteRequest | None = None,
286287
) -> bytes:
287288
"""
@@ -329,7 +330,7 @@ def _get_bytes_sync(
329330
--------
330331
>>> store = MemoryStore()
331332
>>> await store.set("data", Buffer.from_bytes(b"hello world"))
332-
>>> data = store._get_bytes_sync("data", prototype=default_buffer_prototype())
333+
>>> data = store._get_bytes_sync("data")
333334
>>> print(data)
334335
b'hello world'
335336
"""
@@ -340,7 +341,7 @@ async def _get_json(
340341
self,
341342
key: str,
342343
*,
343-
prototype: BufferLike | None = None,
344+
prototype: BufferClassLike | None = None,
344345
byte_range: ByteRequest | None = None,
345346
) -> Any:
346347
"""
@@ -387,7 +388,7 @@ async def _get_json(
387388
>>> store = await MemoryStore.open()
388389
>>> metadata = {"zarr_format": 3, "node_type": "array"}
389390
>>> await store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode()))
390-
>>> data = await store._get_json("zarr.json", prototype=default_buffer_prototype())
391+
>>> data = await store._get_json("zarr.json")
391392
>>> print(data)
392393
{'zarr_format': 3, 'node_type': 'array'}
393394
"""
@@ -398,7 +399,7 @@ def _get_json_sync(
398399
self,
399400
key: str = "",
400401
*,
401-
prototype: BufferLike | None = None,
402+
prototype: BufferClassLike | None = None,
402403
byte_range: ByteRequest | None = None,
403404
) -> Any:
404405
"""
@@ -451,7 +452,7 @@ def _get_json_sync(
451452
>>> store = MemoryStore()
452453
>>> metadata = {"zarr_format": 3, "node_type": "array"}
453454
>>> store.set("zarr.json", Buffer.from_bytes(json.dumps(metadata).encode()))
454-
>>> data = store._get_json_sync("zarr.json", prototype=default_buffer_prototype())
455+
>>> data = store._get_json_sync("zarr.json")
455456
>>> print(data)
456457
{'zarr_format': 3, 'node_type': 'array'}
457458
"""
@@ -461,7 +462,7 @@ def _get_json_sync(
461462
@abstractmethod
462463
async def get_partial_values(
463464
self,
464-
prototype: BufferLike | None,
465+
prototype: BufferClassLike | None,
465466
key_ranges: Iterable[tuple[str, ByteRequest | None]],
466467
) -> list[Buffer | None]:
467468
"""Retrieve possibly partial values from given key_ranges.
@@ -645,7 +646,7 @@ def close(self) -> None:
645646
self._is_open = False
646647

647648
async def _get_many(
648-
self, requests: Iterable[tuple[str, BufferLike | None, ByteRequest | None]]
649+
self, requests: Iterable[tuple[str, BufferClassLike | None, ByteRequest | None]]
649650
) -> AsyncGenerator[tuple[str, Buffer | None], None]:
650651
"""
651652
Retrieve a collection of objects from storage. In general this method does not guarantee
@@ -676,10 +677,8 @@ async def getsize(self, key: str) -> int:
676677
# Note to implementers: this default implementation is very inefficient since
677678
# it requires reading the entire object. Many systems will have ways to get the
678679
# size of an object without reading it.
679-
# avoid circular import
680-
from zarr.core.buffer.core import default_buffer_prototype
681680

682-
value = await self.get(key, prototype=default_buffer_prototype())
681+
value = await self.get(key)
683682
if value is None:
684683
raise FileNotFoundError(key)
685684
return len(value)

src/zarr/core/common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,15 @@
2323

2424
from typing_extensions import ReadOnly
2525

26+
from zarr.core.buffer import Buffer, BufferPrototype
2627
from zarr.core.config import config as zarr_config
2728
from zarr.errors import ZarrRuntimeWarning
2829

2930
if TYPE_CHECKING:
3031
from collections.abc import Awaitable, Callable, Iterator
3132

33+
from zarr.abc.store import BufferClassLike
34+
3235

3336
ZARR_JSON = "zarr.json"
3437
ZARRAY_JSON = ".zarray"
@@ -246,3 +249,17 @@ def _warn_order_kwarg() -> None:
246249
def _default_zarr_format() -> ZarrFormat:
247250
"""Return the default zarr_version"""
248251
return cast("ZarrFormat", int(zarr_config.get("default_zarr_format", 3)))
252+
253+
254+
def parse_bufferclasslike(obj: BufferClassLike | None) -> type[Buffer]:
255+
"""
256+
Take an optional BufferClassLike and return a Buffer class
257+
"""
258+
# Avoid a circular import. Temporary fix until we re-organize modules appropriately.
259+
from zarr.registry import get_buffer_class
260+
261+
if obj is None:
262+
return get_buffer_class()
263+
if isinstance(obj, BufferPrototype):
264+
return obj.buffer
265+
return obj

src/zarr/core/group.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import zarr.api.asynchronous as async_api
1919
from zarr.abc.metadata import Metadata
20-
from zarr.abc.store import Store, set_or_delete
20+
from zarr.abc.store import BufferClassLike, Store, set_or_delete
2121
from zarr.core._info import GroupInfo
2222
from zarr.core.array import (
2323
DEFAULT_FILL_VALUE,
@@ -32,7 +32,6 @@
3232
create_array,
3333
)
3434
from zarr.core.attributes import Attributes
35-
from zarr.core.buffer import default_buffer_prototype
3635
from zarr.core.common import (
3736
JSON,
3837
ZARR_JSON,
@@ -44,6 +43,7 @@
4443
NodeType,
4544
ShapeLike,
4645
ZarrFormat,
46+
parse_bufferclasslike,
4747
parse_shapelike,
4848
)
4949
from zarr.core.config import config
@@ -75,7 +75,7 @@
7575
from typing import Any
7676

7777
from zarr.core.array_spec import ArrayConfigLike
78-
from zarr.core.buffer import Buffer, BufferPrototype
78+
from zarr.core.buffer import Buffer
7979
from zarr.core.chunk_key_encodings import ChunkKeyEncodingLike
8080
from zarr.core.common import MemoryOrder
8181
from zarr.core.dtype import ZDTypeLike
@@ -356,20 +356,25 @@ class GroupMetadata(Metadata):
356356
consolidated_metadata: ConsolidatedMetadata | None = None
357357
node_type: Literal["group"] = field(default="group", init=False)
358358

359-
def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
359+
def to_buffer_dict(self, prototype: BufferClassLike | None = None) -> dict[str, Buffer]:
360+
"""
361+
Convert the metadata document to a dict with string keys and `Buffer` values.
362+
"""
363+
buffer_cls = parse_bufferclasslike(prototype)
364+
360365
json_indent = config.get("json_indent")
361366
if self.zarr_format == 3:
362367
return {
363-
ZARR_JSON: prototype.buffer.from_bytes(
368+
ZARR_JSON: buffer_cls.from_bytes(
364369
json.dumps(self.to_dict(), indent=json_indent, allow_nan=True).encode()
365370
)
366371
}
367372
else:
368373
items = {
369-
ZGROUP_JSON: prototype.buffer.from_bytes(
374+
ZGROUP_JSON: buffer_cls.from_bytes(
370375
json.dumps({"zarr_format": self.zarr_format}, indent=json_indent).encode()
371376
),
372-
ZATTRS_JSON: prototype.buffer.from_bytes(
377+
ZATTRS_JSON: buffer_cls.from_bytes(
373378
json.dumps(self.attributes, indent=json_indent, allow_nan=True).encode()
374379
),
375380
}
@@ -396,7 +401,7 @@ def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
396401
},
397402
}
398403

399-
items[ZMETADATA_V2_JSON] = prototype.buffer.from_bytes(
404+
items[ZMETADATA_V2_JSON] = buffer_cls.from_bytes(
400405
json.dumps(
401406
{"metadata": d, "zarr_consolidated_format": 1}, allow_nan=True
402407
).encode()
@@ -2029,7 +2034,7 @@ async def update_attributes_async(self, new_attributes: dict[str, Any]) -> Group
20292034
new_metadata = replace(self.metadata, attributes=new_attributes)
20302035

20312036
# Write new metadata
2032-
to_save = new_metadata.to_buffer_dict(default_buffer_prototype())
2037+
to_save = new_metadata.to_buffer_dict()
20332038
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
20342039
await asyncio.gather(*awaitables)
20352040

@@ -3615,9 +3620,7 @@ async def _read_metadata_v3(store: Store, path: str) -> ArrayV3Metadata | GroupM
36153620
document stored at store_path.path / zarr.json. If no such document is found, raise a
36163621
FileNotFoundError.
36173622
"""
3618-
zarr_json_bytes = await store.get(
3619-
_join_paths([path, ZARR_JSON]), prototype=default_buffer_prototype()
3620-
)
3623+
zarr_json_bytes = await store.get(_join_paths([path, ZARR_JSON]))
36213624
if zarr_json_bytes is None:
36223625
raise FileNotFoundError(path)
36233626
else:
@@ -3634,9 +3637,9 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM
36343637
# TODO: consider first fetching array metadata, and only fetching group metadata when we don't
36353638
# find an array
36363639
zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather(
3637-
store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()),
3638-
store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()),
3639-
store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()),
3640+
store.get(_join_paths([path, ZARRAY_JSON])),
3641+
store.get(_join_paths([path, ZGROUP_JSON])),
3642+
store.get(_join_paths([path, ZATTRS_JSON])),
36403643
)
36413644

36423645
if zattrs_bytes is None:
@@ -3850,7 +3853,7 @@ def _persist_metadata(
38503853
Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited.
38513854
"""
38523855

3853-
to_save = metadata.to_buffer_dict(default_buffer_prototype())
3856+
to_save = metadata.to_buffer_dict()
38543857
return tuple(
38553858
_set_return_key(store=store, key=_join_paths([path, key]), value=value, semaphore=semaphore)
38563859
for key, value in to_save.items()

src/zarr/core/metadata/io.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import TYPE_CHECKING
55

66
from zarr.abc.store import set_or_delete
7-
from zarr.core.buffer.core import default_buffer_prototype
87
from zarr.errors import ContainsArrayError
98
from zarr.storage._common import StorePath, ensure_no_existing_node
109

@@ -51,7 +50,7 @@ async def save_metadata(
5150
------
5251
ValueError
5352
"""
54-
to_save = metadata.to_buffer_dict(default_buffer_prototype())
53+
to_save = metadata.to_buffer_dict()
5554
set_awaitables = [set_or_delete(store_path / key, value) for key, value in to_save.items()]
5655

5756
if ensure_parents:
@@ -71,9 +70,7 @@ async def save_metadata(
7170
set_awaitables.extend(
7271
[
7372
(parent_store_path / key).set_if_not_exists(value)
74-
for key, value in parent_metadata.to_buffer_dict(
75-
default_buffer_prototype()
76-
).items()
73+
for key, value in parent_metadata.to_buffer_dict().items()
7774
]
7875
)
7976

src/zarr/core/metadata/v2.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import numpy.typing as npt
2020

21+
from zarr.abc.store import BufferClassLike
2122
from zarr.core.buffer import Buffer, BufferPrototype
2223
from zarr.core.dtype.wrapper import (
2324
TBaseDType,
@@ -39,6 +40,7 @@
3940
ZARRAY_JSON,
4041
ZATTRS_JSON,
4142
MemoryOrder,
43+
parse_bufferclasslike,
4244
parse_shapelike,
4345
)
4446
from zarr.core.config import config, parse_indexing_order
@@ -125,15 +127,16 @@ def chunk_grid(self) -> RegularChunkGrid:
125127
def shards(self) -> tuple[int, ...] | None:
126128
return None
127129

128-
def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
130+
def to_buffer_dict(self, prototype: BufferClassLike | None = None) -> dict[str, Buffer]:
131+
buffer_cls = parse_bufferclasslike(prototype)
129132
zarray_dict = self.to_dict()
130133
zattrs_dict = zarray_dict.pop("attributes", {})
131134
json_indent = config.get("json_indent")
132135
return {
133-
ZARRAY_JSON: prototype.buffer.from_bytes(
136+
ZARRAY_JSON: buffer_cls.from_bytes(
134137
json.dumps(zarray_dict, indent=json_indent, allow_nan=True).encode()
135138
),
136-
ZATTRS_JSON: prototype.buffer.from_bytes(
139+
ZATTRS_JSON: buffer_cls.from_bytes(
137140
json.dumps(zattrs_dict, indent=json_indent, allow_nan=True).encode()
138141
),
139142
}

src/zarr/core/metadata/v3.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
if TYPE_CHECKING:
1212
from typing import Self
1313

14+
from zarr.abc.store import BufferClassLike
1415
from zarr.core.buffer import Buffer, BufferPrototype
1516
from zarr.core.chunk_grids import ChunkGrid
1617
from zarr.core.common import JSON
@@ -35,6 +36,7 @@
3536
ZARR_JSON,
3637
DimensionNames,
3738
NamedConfig,
39+
parse_bufferclasslike,
3840
parse_named_configuration,
3941
parse_shapelike,
4042
)
@@ -345,11 +347,12 @@ def get_chunk_spec(
345347
def encode_chunk_key(self, chunk_coords: tuple[int, ...]) -> str:
346348
return self.chunk_key_encoding.encode_chunk_key(chunk_coords)
347349

348-
def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
350+
def to_buffer_dict(self, prototype: BufferClassLike | None = None) -> dict[str, Buffer]:
351+
buffer_cls = parse_bufferclasslike(prototype)
349352
json_indent = config.get("json_indent")
350353
d = self.to_dict()
351354
return {
352-
ZARR_JSON: prototype.buffer.from_bytes(
355+
ZARR_JSON: buffer_cls.from_bytes(
353356
json.dumps(d, allow_nan=True, indent=json_indent).encode()
354357
)
355358
}

0 commit comments

Comments (0)