Skip to content

Commit 63fd54b

Browse files
committed
add numcodecs protocol
1 parent a0c56fb commit 63fd54b

6 files changed

Lines changed: 165 additions & 53 deletions

File tree

src/zarr/abc/codec.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Generic, TypeVar
4+
from collections.abc import Mapping
5+
from typing import TYPE_CHECKING, Generic, TypeGuard, TypeVar
6+
7+
from typing_extensions import ReadOnly, TypedDict
58

69
from zarr.abc.metadata import Metadata
710
from zarr.core.buffer import Buffer, NDBuffer
8-
from zarr.core.common import ChunkCoords, concurrent_map
11+
from zarr.core.common import ChunkCoords, NamedConfig, concurrent_map
912
from zarr.core.config import config
1013

1114
if TYPE_CHECKING:
@@ -34,6 +37,27 @@
3437
CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
3538
CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)
3639

40+
TName = TypeVar("TName", bound=str, covariant=True)
41+
42+
43+
class CodecJSON_V2(TypedDict, Generic[TName]):
44+
"""The JSON representation of a codec for Zarr V2"""
45+
46+
id: ReadOnly[TName]
47+
48+
49+
def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]:
50+
return isinstance(data, Mapping) and "id" in data and isinstance(data["id"], str)
51+
52+
53+
CodecJSON_V3 = str | NamedConfig[str, Mapping[str, object]]
54+
"""The JSON representation of a codec for Zarr V3."""
55+
56+
# The widest type we will *accept* for a codec JSON
57+
# This covers v2 and v3
58+
CodecJSON = str | Mapping[str, object]
59+
"""The widest type of JSON-like input that could specify a codec."""
60+
3761

3862
class BaseCodec(Metadata, Generic[CodecInput, CodecOutput]):
3963
"""Generic base class for codecs.

src/zarr/codecs/_numcodecs.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numcodecs.registry as numcodecs_registry
2+
3+
from zarr.abc.codec import CodecJSON_V2
4+
from zarr.codecs._v2 import Numcodec
5+
6+
7+
def get_numcodec(data: CodecJSON_V2[str]) -> Numcodec:
8+
"""
9+
Resolve a numcodec codec from the numcodecs registry.
10+
11+
This requires the Numcodecs package to be installed.
12+
13+
Parameters
14+
----------
15+
data : CodecJSON_V2
16+
The JSON metadata for the codec.
17+
18+
Returns
19+
-------
20+
codec : Numcodec
21+
22+
Examples
23+
--------
24+
25+
>>> codec = get_numcodec({'id': 'zlib', 'level': 1})
26+
>>> codec
27+
Zlib(level=1)
28+
"""
29+
30+
codec_id = data["id"]
31+
cls = numcodecs_registry.codec_registry.get(codec_id)
32+
if cls is None and data in numcodecs_registry.entries:
33+
cls = numcodecs_registry.entries[data].load()
34+
numcodecs_registry.register_codec(cls, codec_id=data)
35+
if cls is not None:
36+
return cls.from_config({k: v for k, v in data.items() if k != "id"}) # type: ignore[no-any-return]
37+
raise KeyError(data)

src/zarr/codecs/_v2.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,26 +2,77 @@
22

33
import asyncio
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, ClassVar, Self, TypeGuard
66

7-
import numcodecs
87
import numpy as np
98
from numcodecs.compat import ensure_bytes, ensure_ndarray_like
9+
from typing_extensions import Protocol
1010

11-
from zarr.abc.codec import ArrayBytesCodec
11+
from zarr.abc.codec import ArrayBytesCodec, CodecJSON_V2
1212
from zarr.registry import get_ndbuffer_class
1313

1414
if TYPE_CHECKING:
15-
import numcodecs.abc
16-
1715
from zarr.core.array_spec import ArraySpec
1816
from zarr.core.buffer import Buffer, NDBuffer
1917

2018

19+
class Numcodec(Protocol):
20+
"""
21+
A protocol that models the ``numcodecs.abc.Codec`` interface.
22+
"""
23+
24+
codec_id: ClassVar[str]
25+
26+
def encode(self, buf: Buffer | NDBuffer) -> Buffer | NDBuffer: ...
27+
28+
def decode(
29+
self, buf: Buffer | NDBuffer, out: Buffer | NDBuffer | None = None
30+
) -> Buffer | NDBuffer: ...
31+
32+
def get_config(self) -> CodecJSON_V2[str]: ...
33+
34+
@classmethod
35+
def from_config(cls, config: CodecJSON_V2[str]) -> Self: ...
36+
37+
38+
def _is_numcodec(obj: object) -> TypeGuard[Numcodec]:
39+
"""
40+
Check if the given object implements the Numcodec protocol.
41+
42+
The @runtime_checkable decorator does not allow issubclass checks for protocols with non-method
43+
members (i.e., attributes), so we use this function to manually check for the presence of the
44+
required attributes and methods on a given object.
45+
"""
46+
return _is_numcodec_cls(type(obj))
47+
48+
49+
def _is_numcodec_cls(obj: object) -> TypeGuard[type[Numcodec]]:
50+
"""
51+
Check if the given object is a class implements the Numcodec protocol.
52+
53+
The @runtime_checkable decorator does not allow issubclass checks for protocols with non-method
54+
members (i.e., attributes), so we use this function to manually check for the presence of the
55+
required attributes and methods on a given object.
56+
"""
57+
return (
58+
isinstance(obj, type)
59+
and hasattr(obj, "codec_id")
60+
and isinstance(obj.codec_id, str)
61+
and hasattr(obj, "encode")
62+
and callable(obj.encode)
63+
and hasattr(obj, "decode")
64+
and callable(obj.decode)
65+
and hasattr(obj, "get_config")
66+
and callable(obj.get_config)
67+
and hasattr(obj, "from_config")
68+
and callable(obj.from_config)
69+
)
70+
71+
2172
@dataclass(frozen=True)
2273
class V2Codec(ArrayBytesCodec):
23-
filters: tuple[numcodecs.abc.Codec, ...] | None
24-
compressor: numcodecs.abc.Codec | None
74+
filters: tuple[Numcodec, ...] | None
75+
compressor: Numcodec | None
2576

2677
is_fixed_size = False
2778

@@ -33,9 +84,9 @@ async def _decode_single(
3384
cdata = chunk_bytes.as_array_like()
3485
# decompress
3586
if self.compressor:
36-
chunk = await asyncio.to_thread(self.compressor.decode, cdata)
87+
chunk = await asyncio.to_thread(self.compressor.decode, cdata) # type: ignore[arg-type]
3788
else:
38-
chunk = cdata
89+
chunk = cdata # type: ignore[assignment]
3990

4091
# apply filters
4192
if self.filters:
@@ -56,7 +107,7 @@ async def _decode_single(
56107
# is an object array. In this case, we need to convert the object
57108
# array to the correct dtype.
58109

59-
chunk = np.array(chunk).astype(chunk_spec.dtype.to_native_dtype())
110+
chunk = np.array(chunk).astype(chunk_spec.dtype.to_native_dtype()) # type: ignore[assignment]
60111

61112
elif chunk.dtype != object:
62113
# If we end up here, someone must have hacked around with the filters.
@@ -85,17 +136,17 @@ async def _encode_single(
85136
# apply filters
86137
if self.filters:
87138
for f in self.filters:
88-
chunk = await asyncio.to_thread(f.encode, chunk)
139+
chunk = await asyncio.to_thread(f.encode, chunk) # type: ignore[arg-type]
89140

90141
# check object encoding
91142
if ensure_ndarray_like(chunk).dtype == object:
92143
raise RuntimeError("cannot write object array without object codec")
93144

94145
# compress
95146
if self.compressor:
96-
cdata = await asyncio.to_thread(self.compressor.encode, chunk)
147+
cdata = await asyncio.to_thread(self.compressor.encode, chunk) # type: ignore[arg-type]
97148
else:
98-
cdata = chunk
149+
cdata = chunk # type: ignore[assignment]
99150

100151
cdata = ensure_bytes(cdata)
101152
return chunk_spec.prototype.buffer.from_bytes(cdata)

src/zarr/core/_info.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
from typing import TYPE_CHECKING, Literal
66

77
if TYPE_CHECKING:
8-
import numcodecs.abc
9-
108
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec
9+
from zarr.codecs._v2 import Numcodec
1110
from zarr.core.common import ZarrFormat
1211
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
1312

@@ -88,9 +87,9 @@ class ArrayInfo:
8887
_order: Literal["C", "F"]
8988
_read_only: bool
9089
_store_type: str
91-
_filters: tuple[numcodecs.abc.Codec, ...] | tuple[ArrayArrayCodec, ...] = ()
90+
_filters: tuple[Numcodec, ...] | tuple[ArrayArrayCodec, ...] = ()
9291
_serializer: ArrayBytesCodec | None = None
93-
_compressors: tuple[numcodecs.abc.Codec, ...] | tuple[BytesBytesCodec, ...] = ()
92+
_compressors: tuple[Numcodec, ...] | tuple[BytesBytesCodec, ...] = ()
9493
_count_bytes: int | None = None
9594
_count_bytes_stored: int | None = None
9695
_count_chunks_initialized: int | None = None

src/zarr/core/metadata/v2.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

33
import warnings
4-
from collections.abc import Iterable, Sequence
4+
from collections.abc import Iterable, Mapping, Sequence
55
from functools import cached_property
66
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast
77

8-
import numcodecs.abc
9-
8+
from zarr.abc.codec import CodecJSON_V2, _check_codecjson_v2
109
from zarr.abc.metadata import Metadata
10+
from zarr.codecs._numcodecs import get_numcodec
11+
from zarr.codecs._v2 import Numcodec, _is_numcodec
1112
from zarr.core.chunk_grids import RegularChunkGrid
1213
from zarr.core.dtype import get_data_type_from_json
1314
from zarr.core.dtype.common import OBJECT_CODEC_IDS, DTypeSpec_V2
@@ -30,7 +31,6 @@
3031
import json
3132
from dataclasses import dataclass, field, fields, replace
3233

33-
import numcodecs
3434
import numpy as np
3535

3636
from zarr.core.array_spec import ArrayConfig, ArraySpec
@@ -56,7 +56,7 @@ class ArrayV2MetadataDict(TypedDict):
5656

5757

5858
# Union of acceptable types for v2 compressors
59-
CompressorLikev2: TypeAlias = dict[str, JSON] | numcodecs.abc.Codec | None
59+
CompressorLike_V2: TypeAlias = CodecJSON_V2[str] | Numcodec
6060

6161

6262
@dataclass(frozen=True, kw_only=True)
@@ -66,9 +66,9 @@ class ArrayV2Metadata(Metadata):
6666
dtype: ZDType[TBaseDType, TBaseScalar]
6767
fill_value: int | float | str | bytes | None = None
6868
order: MemoryOrder = "C"
69-
filters: tuple[numcodecs.abc.Codec, ...] | None = None
69+
filters: tuple[Numcodec, ...] | None = None
7070
dimension_separator: Literal[".", "/"] = "."
71-
compressor: numcodecs.abc.Codec | None
71+
compressor: Numcodec | None
7272
attributes: dict[str, JSON] = field(default_factory=dict)
7373
zarr_format: Literal[2] = field(init=False, default=2)
7474

@@ -81,8 +81,8 @@ def __init__(
8181
fill_value: Any,
8282
order: MemoryOrder,
8383
dimension_separator: Literal[".", "/"] = ".",
84-
compressor: CompressorLikev2 = None,
85-
filters: Iterable[numcodecs.abc.Codec | dict[str, JSON]] | None = None,
84+
compressor: CompressorLike_V2 | None = None,
85+
filters: Iterable[CompressorLike_V2] | None = None,
8686
attributes: dict[str, JSON] | None = None,
8787
) -> None:
8888
"""
@@ -197,12 +197,12 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
197197

198198
def to_dict(self) -> dict[str, JSON]:
199199
zarray_dict = super().to_dict()
200-
if isinstance(zarray_dict["compressor"], numcodecs.abc.Codec):
200+
if _is_numcodec(zarray_dict["compressor"]):
201201
codec_config = zarray_dict["compressor"].get_config()
202202
# Hotfix for https://github.com/zarr-developers/zarr-python/issues/2647
203-
if codec_config["id"] == "zstd" and not codec_config.get("checksum", False):
204-
codec_config.pop("checksum")
205-
zarray_dict["compressor"] = codec_config
203+
if codec_config.get("id") == "zstd" and not codec_config.get("checksum", False):
204+
codec_config.pop("checksum") # type: ignore[typeddict-item]
205+
zarray_dict["compressor"] = codec_config # type: ignore[assignment]
206206

207207
if zarray_dict["filters"] is not None:
208208
raw_filters = zarray_dict["filters"]
@@ -212,11 +212,12 @@ def to_dict(self) -> dict[str, JSON]:
212212
raise TypeError("Invalid type for filters. Expected a list or tuple.")
213213
new_filters = []
214214
for f in raw_filters:
215-
if isinstance(f, numcodecs.abc.Codec):
215+
if _is_numcodec(f):
216216
new_filters.append(f.get_config())
217217
else:
218218
new_filters.append(f)
219-
zarray_dict["filters"] = new_filters
219+
# TODO: remove the type ignore when we property type the output
220+
zarray_dict["filters"] = new_filters # type: ignore[assignment]
220221

221222
# serialize the fill value after dtype-specific JSON encoding
222223
if self.fill_value is not None:
@@ -262,44 +263,44 @@ def parse_zarr_format(data: object) -> Literal[2]:
262263
raise ValueError(f"Invalid value. Expected 2. Got {data}.")
263264

264265

265-
def parse_filters(data: object) -> tuple[numcodecs.abc.Codec, ...] | None:
266+
def parse_filters(data: object) -> tuple[Numcodec, ...] | None:
266267
"""
267268
Parse a potential tuple of filters
268269
"""
269-
out: list[numcodecs.abc.Codec] = []
270+
out: list[Numcodec] = []
270271

271272
if data is None:
272273
return data
273274
if isinstance(data, Iterable):
274-
for idx, val in enumerate(data):
275-
if isinstance(val, numcodecs.abc.Codec):
275+
for val in data:
276+
if _is_numcodec(val):
276277
out.append(val)
277-
elif isinstance(val, dict):
278-
out.append(numcodecs.get_codec(val))
278+
if _check_codecjson_v2(val):
279+
out.append(get_numcodec(val))
279280
else:
280-
msg = f"Invalid filter at index {idx}. Expected a numcodecs.abc.Codec or a dict representation of numcodecs.abc.Codec. Got {type(val)} instead."
281+
msg = f'Invalid representation of Numcodec. Got {data}, expected a dict with an "id" key or a Numcodec instance.'
281282
raise TypeError(msg)
282283
if len(out) == 0:
283284
# Per the v2 spec, an empty tuple is not allowed -- use None to express "no filters"
284285
return None
285286
else:
286287
return tuple(out)
287288
# take a single codec instance and wrap it in a tuple
288-
if isinstance(data, numcodecs.abc.Codec):
289+
if _is_numcodec(data):
289290
return (data,)
290-
msg = f"Invalid filters. Expected None, an iterable of numcodecs.abc.Codec or dict representations of numcodecs.abc.Codec. Got {type(data)} instead."
291+
msg = f"Invalid filters. Expected None, an iterable of Numcodec or dict representations of Numcodec. Got {type(data)} instead."
291292
raise TypeError(msg)
292293

293294

294-
def parse_compressor(data: object) -> numcodecs.abc.Codec | None:
295+
def parse_compressor(data: object) -> Numcodec | None:
295296
"""
296297
Parse a potential compressor.
297298
"""
298-
if data is None or isinstance(data, numcodecs.abc.Codec):
299+
if data is None or _is_numcodec(data):
299300
return data
300-
if isinstance(data, dict):
301-
return numcodecs.get_codec(data)
302-
msg = f"Invalid compressor. Expected None, a numcodecs.abc.Codec, or a dict representation of a numcodecs.abc.Codec. Got {type(data)} instead."
301+
if _check_codecjson_v2(data):
302+
return get_numcodec(data)
303+
msg = f"Invalid compressor. Expected None, a Numcodec, or a dict representation of a Numcodec. Got {type(data)} instead."
303304
raise ValueError(msg)
304305

305306

@@ -313,7 +314,7 @@ def parse_metadata(data: ArrayV2Metadata) -> ArrayV2Metadata:
313314
return data
314315

315316

316-
def get_object_codec_id(maybe_object_codecs: Sequence[JSON]) -> str | None:
317+
def get_object_codec_id(maybe_object_codecs: Iterable[object]) -> str | None:
317318
"""
318319
Inspect a sequence of codecs / filters for an "object codec", i.e. a codec
319320
that can serialize object arrays to contiguous bytes. Zarr python
@@ -324,7 +325,7 @@ def get_object_codec_id(maybe_object_codecs: Sequence[JSON]) -> str | None:
324325
object_codec_id = None
325326
for maybe_object_codec in maybe_object_codecs:
326327
if (
327-
isinstance(maybe_object_codec, dict)
328+
isinstance(maybe_object_codec, Mapping)
328329
and maybe_object_codec.get("id") in OBJECT_CODEC_IDS
329330
):
330331
return cast("str", maybe_object_codec["id"])

0 commit comments

Comments
 (0)