Skip to content

Commit 1fd5e86

Browse files
committed
wip
1 parent 7b68dab commit 1fd5e86

3 files changed

Lines changed: 126 additions & 29 deletions

File tree

src/zarr/codecs/numcodec.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import asyncio
88
from collections.abc import Mapping
99
from dataclasses import dataclass
10-
from typing import TYPE_CHECKING, Literal, Self, overload
10+
from typing import TYPE_CHECKING, Callable, Literal, Self, TypeGuard, overload
1111

1212
import numpy as np
1313
from typing_extensions import Protocol, runtime_checkable
@@ -49,6 +49,26 @@ def get_config(self) -> CodecConfig_V2[str]: ...
4949
@classmethod
5050
def from_config(cls, config: CodecConfig_V2[str]) -> Self: ...
5151

52+
def is_numcodec_cls(obj: object) -> TypeGuard[type[Numcodec]]:
53+
"""
54+
Check if the given object implements the Numcodec protocol. Because the @runtime_checkable
55+
decorator does not allow issubclass checks for protocols with non-method members (i.e., attributes),
56+
we need to manually check for the presence of the required attributes and methods.
57+
"""
58+
return (
59+
isinstance(obj, type) and
60+
hasattr(obj, "codec_id") and
61+
isinstance(obj.codec_id, str) and
62+
hasattr(obj, "encode") and
63+
callable(obj.encode) and
64+
hasattr(obj, "decode") and
65+
callable(obj.decode) and
66+
hasattr(obj, "get_config") and
67+
callable(obj.get_config) and
68+
hasattr(obj, "from_config") and
69+
callable(obj.from_config)
70+
)
71+
5272

5373
@dataclass(frozen=True, kw_only=True)
5474
class NumcodecsAdapter:
@@ -104,7 +124,7 @@ async def _encode_single(self, chunk_data: Buffer, chunk_spec: ArraySpec) -> Buf
104124

105125

106126
@dataclass(kw_only=True, frozen=True)
107-
class NumcodecsArrayCodec(NumcodecsAdapter, ArrayArrayCodec):
127+
class NumcodecsArrayArrayCodec(NumcodecsAdapter, ArrayArrayCodec):
108128
async def _decode_single(self, chunk_data: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer:
109129
chunk_ndarray = chunk_data.as_ndarray_like()
110130
out = await asyncio.to_thread(self._codec.decode, chunk_ndarray)

src/zarr/core/metadata/v2.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from zarr.abc.codec import Codec
1111
from zarr.abc.metadata import Metadata
12-
from zarr.codecs.numcodec import NumcodecsAdapter
12+
from zarr.codecs.numcodec import Numcodec, NumcodecsAdapter
1313
from zarr.core.chunk_grids import RegularChunkGrid
1414
from zarr.core.dtype import get_data_type_from_json
1515
from zarr.core.dtype.common import OBJECT_CODEC_IDS, DTypeSpec_V2
@@ -199,29 +199,21 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
199199

200200
def to_dict(self) -> dict[str, JSON]:
201201
zarray_dict = super().to_dict()
202-
if isinstance(zarray_dict["compressor"], numcodecs.abc.Codec):
202+
if isinstance(zarray_dict["compressor"], Numcodec):
203+
raise ValueError('raw numcodecs codecs are not allowed.')
203204
codec_config = zarray_dict["compressor"].get_config()
204205
# Hotfix for https://github.com/zarr-developers/zarr-python/issues/2647
205206
if codec_config["id"] == "zstd" and not codec_config.get("checksum", False):
206207
codec_config.pop("checksum")
207208
zarray_dict["compressor"] = codec_config
208209

209-
if isinstance(zarray_dict["compressor"], NumcodecsAdapter):
210-
zarray_dict["compressor"] = zarray_dict["compressor"].to_json(zarr_format=2)
211-
210+
zarray_dict["compressor"] = self.compressor.to_json(zarr_format=2)
211+
new_filters = []
212212
if zarray_dict["filters"] is not None:
213-
raw_filters = zarray_dict["filters"]
214-
# TODO: remove this when we can statically type the output JSON data structure
215-
# entirely
216-
if not isinstance(raw_filters, list | tuple):
217-
raise TypeError("Invalid type for filters. Expected a list or tuple.")
218-
new_filters = []
219-
for f in raw_filters:
220-
if isinstance(f, numcodecs.abc.Codec):
221-
new_filters.append(f.get_config())
222-
else:
223-
new_filters.append(f)
224-
zarray_dict["filters"] = new_filters
213+
new_filters.append(f.to_json(zarr_format=2))
214+
else:
215+
new_filters = None
216+
zarray_dict["filters"] = new_filters
225217

226218
# serialize the fill value after dtype-specific JSON encoding
227219
if self.fill_value is not None:

src/zarr/registry.py

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
if TYPE_CHECKING:
1212
from importlib.metadata import EntryPoint
13-
13+
from zarr.codecs.numcodec import Numcodec
1414
from zarr.abc.codec import (
1515
ArrayArrayCodec,
1616
ArrayBytesCodec,
@@ -53,6 +53,10 @@ def register(self, cls: type[T], qualname: str | None = None) -> None:
5353
self[qualname] = cls
5454

5555

56+
__filter_registries: dict[str, Registry[ArrayArrayCodec]] = defaultdict(Registry)
57+
__serializer_registries: dict[str, Registry[ArrayBytesCodec]] = defaultdict(Registry)
58+
__compressor_registries: dict[str, Registry[BytesBytesCodec]] = defaultdict(Registry)
59+
5660
__codec_registries: dict[str, Registry[Codec]] = defaultdict(Registry)
5761
__pipeline_registry: Registry[CodecPipeline] = Registry()
5862
__buffer_registry: Registry[Buffer] = Registry()
@@ -117,17 +121,59 @@ def _collect_entrypoints() -> list[Registry[Any]]:
117121
def _reload_config() -> None:
118122
config.refresh()
119123

120-
121124
def fully_qualified_name(cls: type) -> str:
    """Return the dotted path of ``cls``: its ``__module__`` joined to its ``__qualname__``."""
    return ".".join((cls.__module__, cls.__qualname__))
124127

128+
def register_filter(key: str, codec_cls: type[ArrayArrayCodec]) -> None:
    """
    Register ``codec_cls`` as a filter implementation for ``key``.

    A ``Registry`` for ``key`` is created in ``__filter_registries`` if none exists yet,
    and the class is then added to it.
    """
    registry = __filter_registries.get(key)
    if registry is None:
        registry = __filter_registries[key] = Registry()
    registry.register(codec_cls)
132+
133+
def register_serializer(key: str, codec_cls: type[ArrayBytesCodec]) -> None:
    """
    Register ``codec_cls`` as a serializer implementation for ``key``.

    Classes recognised by ``is_numcodec_cls`` are passed through
    ``NumcodecsArrayBytesCodec`` before being stored; any other class is registered
    unchanged.

    Parameters
    ----------
    key : str
        Identifier under which the codec is registered.
    codec_cls : type[ArrayBytesCodec]
        The codec class to register.
    """
    # Function-scoped import, kept local exactly as in the original definition.
    from zarr.codecs.numcodec import NumcodecsArrayBytesCodec, is_numcodec_cls

    if is_numcodec_cls(codec_cls):
        # NOTE(review): this builds an adapter *instance* around a class and registers the
        # instance, while ``Registry.register`` is annotated to take a class -- confirm intent.
        _codec_cls = NumcodecsArrayBytesCodec(_codec=codec_cls)
    else:
        _codec_cls = codec_cls
    if key not in __serializer_registries:
        __serializer_registries[key] = Registry()
    __serializer_registries[key].register(_codec_cls)
152+
153+
def register_compressor(key: str, codec_cls: type[BytesBytesCodec | Numcodec]) -> None:
    """
    Register ``codec_cls`` as a compressor implementation for ``key``.

    A class recognised by ``is_numcodec_cls`` is first passed through
    ``NumcodecsBytesBytesCodec``; anything else is stored as given. The ``Registry`` for
    ``key`` is created in ``__compressor_registries`` on first use.
    """
    from zarr.codecs.numcodec import NumcodecsBytesBytesCodec, is_numcodec_cls

    to_register = (
        NumcodecsBytesBytesCodec(_codec=codec_cls) if is_numcodec_cls(codec_cls) else codec_cls
    )
    registry = __compressor_registries.get(key)
    if registry is None:
        registry = __compressor_registries[key] = Registry()
    registry.register(to_register)
125162

126163
def register_codec(key: str, codec_cls: type[Codec]) -> None:
    """
    Register ``codec_cls`` under ``key``.

    The class is added to the registry for its codec kind -- serializer for
    ``ArrayBytesCodec`` subclasses, filter for ``ArrayArrayCodec`` subclasses, compressor
    for everything else -- and to the generic ``__codec_registries`` mapping.

    Parameters
    ----------
    key : str
        Identifier under which the codec is registered.
    codec_cls : type[Codec]
        The codec class to register.
    """
    from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec

    if issubclass(codec_cls, ArrayBytesCodec):
        register_serializer(key, codec_cls)
    elif issubclass(codec_cls, ArrayArrayCodec):
        register_filter(key, codec_cls)
    else:
        register_compressor(key, codec_cls)

    # Keep the generic registry populated too: without this, ``__codec_registries`` is
    # never filled by this function.
    # NOTE(review): presumably ``get_codec_class`` still reads this mapping -- confirm.
    if key not in __codec_registries:
        __codec_registries[key] = Registry()
    __codec_registries[key].register(codec_cls)
131177

132178
def register_pipeline(pipe_cls: type[CodecPipeline]) -> None:
    """Add ``pipe_cls`` to the module-level pipeline registry."""
    __pipeline_registry.register(pipe_cls)
@@ -140,6 +186,41 @@ def register_ndbuffer(cls: type[NDBuffer], qualname: str | None = None) -> None:
140186
def register_buffer(cls: type[Buffer], qualname: str | None = None) -> None:
    """Add ``cls`` to the module-level buffer registry, forwarding ``qualname`` as given."""
    __buffer_registry.register(cls, qualname)
142188

189+
def get_filter_class(key: str, reload_config: bool = False) -> type[ArrayArrayCodec]:
    """
    Look up the filter class registered under ``key``.

    Parameters
    ----------
    key : str
        Identifier the filter was registered under.
    reload_config : bool, default False
        Forwarded to ``_get_codec_class``.

    Returns
    -------
    type[ArrayArrayCodec]
        The class selected by ``_get_codec_class`` from the filter registries.
    """
    # Filters are stored in ``__filter_registries`` by ``register_filter``; searching the
    # serializer registries would never find them.
    return _get_codec_class(key, __filter_registries, reload_config=reload_config)
191+
192+
def get_serializer_class(key: str, reload_config: bool = False) -> type[ArrayBytesCodec]:
    """Look up the serializer class registered under ``key`` via ``_get_codec_class``."""
    return _get_codec_class(key, __serializer_registries, reload_config=reload_config)
194+
195+
def get_compressor_class(key: str, reload_config: bool = False) -> type[BytesBytesCodec]:
    """Look up the compressor class registered under ``key`` via ``_get_codec_class``."""
    return _get_codec_class(key, __compressor_registries, reload_config=reload_config)
197+
198+
def _get_codec_class(
    key: str, registry: dict[str, Registry[Codec]], *, reload_config: bool = False
) -> type[Codec]:
    """
    Select the codec class registered under ``key`` in ``registry``.

    Parameters
    ----------
    key : str
        Identifier of the codec to look up.
    registry : dict[str, Registry[Codec]]
        Mapping from codec identifier to the registry of its implementations.
    reload_config : bool, default False
        If True, ``_reload_config()`` is called before the lookup.

    Returns
    -------
    type[Codec]
        The only registered implementation if there is exactly one. Otherwise the
        implementation named by the ``codecs`` config entry for ``key``, or, when no such
        entry exists, the last implementation in the registry (a warning is emitted).

    Raises
    ------
    KeyError
        If nothing is registered under ``key``, or if the configured implementation is not
        present in the registry.
    """
    if reload_config:
        _reload_config()

    # ``.get`` rather than indexing: when ``registry`` is a defaultdict, indexing a missing
    # key would insert an empty entry as a side effect of a failed lookup.
    codec_classes = registry.get(key)
    if codec_classes is not None:
        codec_classes.lazy_load()
    if not codec_classes:
        raise KeyError(key)

    config_entry = config.get("codecs", {}).get(key)
    if config_entry is None:
        if len(codec_classes) == 1:
            return next(iter(codec_classes.values()))
        # stacklevel=3 attributes the warning to the caller of the public ``get_*_class``
        # wrapper instead of to the wrapper itself.
        warnings.warn(
            f"Codec '{key}' not configured in config. Selecting any implementation.",
            stacklevel=3,
        )
        return list(codec_classes.values())[-1]

    selected_codec_cls = codec_classes[config_entry]
    if selected_codec_cls:
        return selected_codec_cls
    raise KeyError(key)
143224

144225
def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]:
145226
if reload_config:
@@ -189,7 +270,7 @@ def _parse_bytes_bytes_codec(data: dict[str, JSON] | Codec | Numcodec) -> BytesB
189270

190271
result: BytesBytesCodec
191272
if isinstance(data, dict):
192-
result = _resolve_codec(data)
273+
result = get_compressor_class(data["name"]).from_dict(data)
193274
if not isinstance(result, BytesBytesCodec):
194275
msg = f"Expected a dict representation of a BytesBytesCodec; got a dict representation of a {type(result)} instead."
195276
raise TypeError(msg)
@@ -202,39 +283,43 @@ def _parse_bytes_bytes_codec(data: dict[str, JSON] | Codec | Numcodec) -> BytesB
202283
return result
203284

204285

205-
def _parse_array_bytes_codec(data: dict[str, JSON] | Codec) -> ArrayBytesCodec:
286+
def _parse_array_bytes_codec(data: dict[str, JSON] | Codec | Numcodec) -> ArrayBytesCodec:
    """
    Coerce ``data`` into an ``ArrayBytesCodec`` instance.

    - A dict is resolved by fetching the serializer class named by ``data["name"]`` and
      calling its ``from_dict``; the result must be an ``ArrayBytesCodec``.
    - An object satisfying the ``Numcodec`` protocol is wrapped in
      ``NumcodecsArrayBytesCodec``.
    - An ``ArrayBytesCodec`` instance is returned as is.

    Raises ``TypeError`` when the dict resolves to a different kind of codec or when
    ``data`` is none of the above.
    """
    from zarr.abc.codec import ArrayBytesCodec
    from zarr.codecs.numcodec import Numcodec, NumcodecsArrayBytesCodec

    if isinstance(data, dict):
        result = get_serializer_class(data["name"]).from_dict(data)
        if isinstance(result, ArrayBytesCodec):
            return result
        msg = f"Expected a dict representation of a ArrayBytesCodec; got a dict representation of a {type(result)} instead."
        raise TypeError(msg)

    if isinstance(data, Numcodec):
        return NumcodecsArrayBytesCodec(_codec=data)

    if isinstance(data, ArrayBytesCodec):
        return data
    raise TypeError(f"Expected a ArrayBytesCodec. Got {type(data)} instead.")
223306

224307

225-
def _parse_array_array_codec(data: dict[str, JSON] | Codec) -> ArrayArrayCodec:
308+
def _parse_array_array_codec(data: dict[str, JSON] | Codec | Numcodec) -> ArrayArrayCodec:
226309
"""
227310
Normalize the input to a ``ArrayArrayCodec`` instance.
228311
If the input is already a ``ArrayArrayCodec``, it is returned as is. If the input is a dict, it
229312
is converted to an ``ArrayArrayCodec`` instance via ``get_filter_class(data["name"]).from_dict``.
230313
"""
231314
from zarr.abc.codec import ArrayArrayCodec
232-
315+
from zarr.codecs.numcodec import Numcodec, NumcodecsArrayArrayCodec
233316
if isinstance(data, dict):
234-
result = _resolve_codec(data)
317+
result = get_filter_class(data["name"]).from_dict(data)
235318
if not isinstance(result, ArrayArrayCodec):
236319
msg = f"Expected a dict representation of a ArrayArrayCodec; got a dict representation of a {type(result)} instead."
237320
raise TypeError(msg)
321+
elif isinstance(data, Numcodec):
322+
return NumcodecsArrayArrayCodec(_codec=data)
238323
else:
239324
if not isinstance(data, ArrayArrayCodec):
240325
raise TypeError(f"Expected a ArrayArrayCodec. Got {type(data)} instead.")

0 commit comments

Comments
 (0)