Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 43 additions & 43 deletions src/zarr/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ def register(self, cls: type[T], qualname: str | None = None) -> None:
self[qualname] = cls


__codec_registries: dict[str, Registry[Codec]] = defaultdict(Registry)
__pipeline_registry: Registry[CodecPipeline] = Registry()
__buffer_registry: Registry[Buffer] = Registry()
__ndbuffer_registry: Registry[NDBuffer] = Registry()
__chunk_key_encoding_registry: Registry[ChunkKeyEncoding] = Registry()
_codec_registries: dict[str, Registry[Codec]] = defaultdict(Registry)
_pipeline_registry: Registry[CodecPipeline] = Registry()
_buffer_registry: Registry[Buffer] = Registry()
_ndbuffer_registry: Registry[NDBuffer] = Registry()
_chunk_key_encoding_registry: Registry[ChunkKeyEncoding] = Registry()

"""
The registry module is responsible for managing implementations of codecs,
Expand Down Expand Up @@ -93,37 +93,37 @@ def _collect_entrypoints() -> list[Registry[Any]]:
"""
entry_points = get_entry_points()

__buffer_registry.lazy_load_list.extend(entry_points.select(group="zarr.buffer"))
__buffer_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="buffer"))
__ndbuffer_registry.lazy_load_list.extend(entry_points.select(group="zarr.ndbuffer"))
__ndbuffer_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="ndbuffer"))
_buffer_registry.lazy_load_list.extend(entry_points.select(group="zarr.buffer"))
_buffer_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="buffer"))
_ndbuffer_registry.lazy_load_list.extend(entry_points.select(group="zarr.ndbuffer"))
_ndbuffer_registry.lazy_load_list.extend(entry_points.select(group="zarr", name="ndbuffer"))

data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr.data_type"))
data_type_registry._lazy_load_list.extend(entry_points.select(group="zarr", name="data_type"))

__chunk_key_encoding_registry.lazy_load_list.extend(
_chunk_key_encoding_registry.lazy_load_list.extend(
entry_points.select(group="zarr.chunk_key_encoding")
)
__chunk_key_encoding_registry.lazy_load_list.extend(
_chunk_key_encoding_registry.lazy_load_list.extend(
entry_points.select(group="zarr", name="chunk_key_encoding")
)

__pipeline_registry.lazy_load_list.extend(entry_points.select(group="zarr.codec_pipeline"))
__pipeline_registry.lazy_load_list.extend(
_pipeline_registry.lazy_load_list.extend(entry_points.select(group="zarr.codec_pipeline"))
_pipeline_registry.lazy_load_list.extend(
entry_points.select(group="zarr", name="codec_pipeline")
)
for e in entry_points.select(group="zarr.codecs"):
__codec_registries[e.name].lazy_load_list.append(e)
_codec_registries[e.name].lazy_load_list.append(e)
for group in entry_points.groups:
if group.startswith("zarr.codecs."):
codec_name = group.split(".")[2]
__codec_registries[codec_name].lazy_load_list.extend(entry_points.select(group=group))
_codec_registries[codec_name].lazy_load_list.extend(entry_points.select(group=group))
return [
*__codec_registries.values(),
__pipeline_registry,
__buffer_registry,
__ndbuffer_registry,
__chunk_key_encoding_registry,
*_codec_registries.values(),
_pipeline_registry,
_buffer_registry,
_ndbuffer_registry,
_chunk_key_encoding_registry,
]


Expand All @@ -137,36 +137,36 @@ def fully_qualified_name(cls: type) -> str:


def register_codec(key: str, codec_cls: type[Codec], *, qualname: str | None = None) -> None:
if key not in __codec_registries:
__codec_registries[key] = Registry()
__codec_registries[key].register(codec_cls, qualname=qualname)
if key not in _codec_registries:
_codec_registries[key] = Registry()
_codec_registries[key].register(codec_cls, qualname=qualname)


def register_pipeline(pipe_cls: type[CodecPipeline]) -> None:
__pipeline_registry.register(pipe_cls)
_pipeline_registry.register(pipe_cls)


def register_ndbuffer(cls: type[NDBuffer], qualname: str | None = None) -> None:
__ndbuffer_registry.register(cls, qualname)
_ndbuffer_registry.register(cls, qualname)


def register_buffer(cls: type[Buffer], qualname: str | None = None) -> None:
__buffer_registry.register(cls, qualname)
_buffer_registry.register(cls, qualname)


def register_chunk_key_encoding(key: str, cls: type) -> None:
__chunk_key_encoding_registry.register(cls, key)
_chunk_key_encoding_registry.register(cls, key)


def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]:
if reload_config:
_reload_config()

if key in __codec_registries:
if key in _codec_registries:
# logger.debug("Auto loading codec '%s' from entrypoint", codec_id)
__codec_registries[key].lazy_load()
_codec_registries[key].lazy_load()

codec_classes = __codec_registries[key]
codec_classes = _codec_registries[key]
if not codec_classes:
raise KeyError(key)
config_entry = config.get("codecs", {}).get(key)
Expand Down Expand Up @@ -257,50 +257,50 @@ def _parse_array_array_codec(data: dict[str, JSON] | Codec) -> ArrayArrayCodec:
def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]:
if reload_config:
_reload_config()
__pipeline_registry.lazy_load()
_pipeline_registry.lazy_load()
path = config.get("codec_pipeline.path")
pipeline_class = __pipeline_registry.get(path)
pipeline_class = _pipeline_registry.get(path)
if pipeline_class:
return pipeline_class
raise BadConfigError(
f"Pipeline class '{path}' not found in registered pipelines: {list(__pipeline_registry)}."
f"Pipeline class '{path}' not found in registered pipelines: {list(_pipeline_registry)}."
)


def get_buffer_class(reload_config: bool = False) -> type[Buffer]:
if reload_config:
_reload_config()
__buffer_registry.lazy_load()
_buffer_registry.lazy_load()

path = config.get("buffer")
buffer_class = __buffer_registry.get(path)
buffer_class = _buffer_registry.get(path)
if buffer_class:
return buffer_class
raise BadConfigError(
f"Buffer class '{path}' not found in registered buffers: {list(__buffer_registry)}."
f"Buffer class '{path}' not found in registered buffers: {list(_buffer_registry)}."
)


def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]:
if reload_config:
_reload_config()
__ndbuffer_registry.lazy_load()
_ndbuffer_registry.lazy_load()
path = config.get("ndbuffer")
ndbuffer_class = __ndbuffer_registry.get(path)
ndbuffer_class = _ndbuffer_registry.get(path)
if ndbuffer_class:
return ndbuffer_class
raise BadConfigError(
f"NDBuffer class '{path}' not found in registered buffers: {list(__ndbuffer_registry)}."
f"NDBuffer class '{path}' not found in registered buffers: {list(_ndbuffer_registry)}."
)


def get_chunk_key_encoding_class(key: str) -> type[ChunkKeyEncoding]:
__chunk_key_encoding_registry.lazy_load(use_entrypoint_name=True)
if key not in __chunk_key_encoding_registry:
_chunk_key_encoding_registry.lazy_load(use_entrypoint_name=True)
if key not in _chunk_key_encoding_registry:
raise KeyError(
f"Chunk key encoding '{key}' not found in registered chunk key encodings: {list(__chunk_key_encoding_registry)}."
f"Chunk key encoding '{key}' not found in registered chunk key encodings: {list(_chunk_key_encoding_registry)}."
)
return __chunk_key_encoding_registry[key]
return _chunk_key_encoding_registry[key]


_collect_entrypoints()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_metadata/test_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def test_parse_codecs_unknown_codec_raises(monkeypatch: pytest.MonkeyPatch) -> N
from zarr.registry import Registry

# to make sure the codec is always unknown (not sure if that's necessary)
monkeypatch.setattr(zarr.registry, "__codec_registries", defaultdict(Registry))
monkeypatch.setattr(zarr.registry, "_codec_registries", defaultdict(Registry))

codecs = [{"name": "unknown"}]
with pytest.raises(UnknownCodecError):
Expand Down
Loading