Skip to content

Commit 9269120

Browse files
committed
feat: define the codec name: codec class mapping as part of array runtime configuration
1 parent 5e8eeaa commit 9269120

13 files changed

Lines changed: 239 additions & 62 deletions

File tree

src/zarr/codecs/sharding.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from zarr.codecs.bytes import BytesCodec
2828
from zarr.codecs.crc32c_ import Crc32cCodec
29-
from zarr.core.array_spec import ArrayConfig, ArraySpec
29+
from zarr.core.array_spec import ArraySpec, ArraySpecConfig, parse_codec_class_map
3030
from zarr.core.buffer import (
3131
Buffer,
3232
BufferPrototype,
@@ -319,10 +319,13 @@ def __init__(
319319
codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(),),
320320
index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()),
321321
index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end,
322+
codec_class_map: Mapping[str, type[Codec]] | None = None,
322323
) -> None:
324+
if codec_class_map is None:
325+
codec_class_map = parse_codec_class_map(None)
323326
chunk_shape_parsed = parse_shapelike(chunk_shape)
324-
codecs_parsed = parse_codecs(codecs)
325-
index_codecs_parsed = parse_codecs(index_codecs)
327+
codecs_parsed = parse_codecs(codecs, codec_class_map=codec_class_map)
328+
index_codecs_parsed = parse_codecs(index_codecs, codec_class_map=codec_class_map)
326329
index_location_parsed = parse_index_location(index_location)
327330

328331
object.__setattr__(self, "chunk_shape", chunk_shape_parsed)
@@ -737,7 +740,7 @@ def _get_index_chunk_spec(self, chunks_per_shard: tuple[int, ...]) -> ArraySpec:
737740
shape=chunks_per_shard + (2,),
738741
dtype=UInt64(endianness="little"),
739742
fill_value=MAX_UINT_64,
740-
config=ArrayConfig(
743+
config=ArraySpecConfig(
741744
order="C", write_empty_chunks=False
742745
), # Note: this is hard-coded for simplicity -- it is not surfaced into user code,
743746
prototype=default_buffer_prototype(),

src/zarr/core/array.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@
2828
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
2929
from zarr.codecs.zstd import ZstdCodec
3030
from zarr.core._info import ArrayInfo
31-
from zarr.core.array_spec import ArrayConfig, ArrayConfigLike, ArraySpec, parse_array_config
31+
from zarr.core.array_spec import (
32+
ArrayConfig,
33+
ArrayConfigLike,
34+
ArraySpec,
35+
ArraySpecConfig,
36+
parse_array_config,
37+
)
3238
from zarr.core.attributes import Attributes
3339
from zarr.core.buffer import (
3440
BufferPrototype,
@@ -197,13 +203,13 @@ def _chunk_sizes_from_shape(
197203
return tuple(result)
198204

199205

200-
def parse_array_metadata(data: Any) -> ArrayMetadata:
206+
def parse_array_metadata(data: object, config: ArrayConfig) -> ArrayMetadata:
201207
if isinstance(data, ArrayMetadata):
202-
return data
208+
return data.with_config(config)
203209
elif isinstance(data, dict):
204210
zarr_format = data.get("zarr_format")
205211
if zarr_format == 3:
206-
meta_out = ArrayV3Metadata.from_dict(data)
212+
meta_out = ArrayV3Metadata.from_dict(data, config=config)
207213
if len(meta_out.storage_transformers) > 0:
208214
msg = (
209215
f"Array metadata contains storage transformers: {meta_out.storage_transformers}."
@@ -212,7 +218,7 @@ def parse_array_metadata(data: Any) -> ArrayMetadata:
212218
raise ValueError(msg)
213219
return meta_out
214220
elif zarr_format == 2:
215-
return ArrayV2Metadata.from_dict(data)
221+
return ArrayV2Metadata.from_dict(data, config=config)
216222
else:
217223
raise ValueError(f"Invalid zarr_format: {zarr_format}. Expected 2 or 3")
218224
raise TypeError # pragma: no cover
@@ -353,8 +359,8 @@ def __init__(
353359
store_path: StorePath,
354360
config: ArrayConfigLike | None = None,
355361
) -> None:
356-
metadata_parsed = parse_array_metadata(metadata)
357362
config_parsed = parse_array_config(config)
363+
metadata_parsed = parse_array_metadata(metadata, config=config_parsed)
358364

359365
object.__setattr__(self, "metadata", metadata_parsed)
360366
object.__setattr__(self, "store_path", store_path)
@@ -5769,11 +5775,16 @@ def _get_chunk_spec(
57695775
spec = chunk_grid[chunk_coords]
57705776
if spec is None:
57715777
raise IndexError(f"Chunk coordinates {chunk_coords} are out of bounds.")
5778+
spec_config = ArraySpecConfig(
5779+
order=array_config.order,
5780+
read_missing_chunks=array_config.read_missing_chunks,
5781+
write_empty_chunks=array_config.write_empty_chunks,
5782+
)
57725783
return ArraySpec(
57735784
shape=spec.codec_shape,
57745785
dtype=metadata.dtype,
57755786
fill_value=metadata.fill_value,
5776-
config=array_config,
5787+
config=spec_config,
57775788
prototype=prototype,
57785789
)
57795790

src/zarr/core/array_spec.py

Lines changed: 155 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

3-
from collections.abc import Mapping
43
from dataclasses import dataclass, fields
5-
from typing import TYPE_CHECKING, Any, Literal, Self, TypedDict, cast
4+
from typing import TYPE_CHECKING, Any, Final, Literal, Self, TypedDict, cast
65

6+
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
77
from zarr.core.common import (
88
MemoryOrder,
99
parse_bool,
@@ -14,13 +14,35 @@
1414
from zarr.core.config import config as zarr_config
1515

1616
if TYPE_CHECKING:
17+
from collections.abc import Mapping
1718
from typing import NotRequired
1819

1920
from zarr.core.buffer import BufferPrototype
2021
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
2122

2223

24+
class CodecPipelineRequest(TypedDict):
25+
"""
26+
A dictionary model of a request for a codec pipeline.
27+
"""
28+
29+
class_path: str
30+
options: NotRequired[dict[str, object]]
31+
32+
2333
class ArrayConfigParams(TypedDict):
34+
"""
35+
A TypedDict model of the attributes of an ArrayConfig class.
36+
"""
37+
38+
order: MemoryOrder
39+
write_empty_chunks: bool
40+
read_missing_chunks: bool
41+
codec_class_map: Mapping[str, object]
42+
codec_pipeline_class: CodecPipelineRequest
43+
44+
45+
class ArrayConfigRequest(TypedDict):
2446
"""
2547
A TypedDict model of the attributes of an ArrayConfig class, but with no required fields.
2648
This allows for partial construction of an ArrayConfig, with the assumption that the unset
@@ -30,6 +52,29 @@ class ArrayConfigParams(TypedDict):
3052
order: NotRequired[MemoryOrder]
3153
write_empty_chunks: NotRequired[bool]
3254
read_missing_chunks: NotRequired[bool]
55+
codec_class_map: NotRequired[
56+
Mapping[str, type[ArrayArrayCodec | ArrayBytesCodec | BytesBytesCodec]]
57+
]
58+
codec_pipeline_class: NotRequired[CodecPipelineRequest]
59+
60+
61+
ArrayConfigKeys = Literal[
62+
"order", "write_empty_chunks", "read_missing_chunks", "codec_class_map", "codec_pipeline_class"
63+
]
64+
65+
ARRAY_CONFIG_PARAMS_KEYS: Final[set[str]] = {
66+
"order",
67+
"write_empty_chunks",
68+
"read_missing_chunks",
69+
"codec_class_map",
70+
"codec_pipeline_class",
71+
}
72+
ARRAY_CONFIG_PARAMS_KEYS_STATIC: Final[set[str]] = {
73+
"order",
74+
"write_empty_chunks",
75+
"read_missing_chunks",
76+
}
77+
"""The keys of the ArrayConfigParams object that are static and retrievable from the config"""
3378

3479

3580
@dataclass(frozen=True)
@@ -46,13 +91,14 @@ class ArrayConfig:
4691
read_missing_chunks : bool, default is True
4792
If True, missing chunks will be filled with the array's fill value on read.
4893
If False, reading missing chunks will raise a ``ChunkNotFoundError``.
49-
codec_classes : Mapping[str, object] | None, default is None
50-
A codec name : codec class mapping that defines the codec classes available
51-
for this array. Defaults to `None`, in which case a default collection of codecs
94+
codec_class_map : Mapping[str, object] | None, default is None
95+
A request for a codec name : codec class mapping that defines the codec classes available
96+
for array creation. Defaults to `None`, in which case a default collection of codecs
5297
is retrieved from the global config object.
53-
data_type_classes : set[ZDType] | None, default is None.
54-
A set of data type classes to use
55-
A data type identi
98+
codec_pipeline_class : CodecPipelineRequest | None, default = None
99+
A request for a codec pipeline class to be used for orchestrating chunk encoding and
100+
decoding. Defaults to `None`, in which case the default codec pipeline request
101+
is retrieved from information in the global config object.
56102
57103
Attributes
58104
----------
@@ -63,50 +109,62 @@ class ArrayConfig:
63109
read_missing_chunks : bool
64110
If True, missing chunks will be filled with the array's fill value on read.
65111
If False, reading missing chunks will raise a ``ChunkNotFoundError``.
66-
codec_classes : Mapping[str, object]
112+
codec_class_map : Mapping[str, object]
67113
A codec name : codec class mapping that defines the codec classes available
68-
for this array.
69-
data_type_clas
114+
for array creation.
115+
codec_pipeline_class : CodecPipelineRequest
116+
A request for a pipeline class that will be used for orchestrating chunk encoding and
117+
decoding.
70118
"""
71119

72120
order: MemoryOrder
73121
write_empty_chunks: bool
74122
read_missing_chunks: bool
75-
codec_classes: Mapping[str, object]
76-
data_type_classes: set[ZDType[Any, Any]]
77-
codec_pipeline_class: object
123+
codec_class_map: Mapping[str, type[Codec]]
124+
codec_pipeline_class: CodecPipelineRequest
78125

79126
def __init__(
80127
self,
81128
order: MemoryOrder,
82129
write_empty_chunks: bool,
83130
*,
84131
read_missing_chunks: bool = True,
85-
codec_class_map: Mapping[str, object] | None = None,
86-
codec_pipeline_class: object | None = None,
132+
codec_class_map: Mapping[str, type[ArrayBytesCodec | ArrayArrayCodec | BytesBytesCodec]]
133+
| None = None,
134+
codec_pipeline_class: CodecPipelineRequest | None = None,
87135
) -> None:
88136
order_parsed = parse_order(order)
89137
write_empty_chunks_parsed = parse_bool(write_empty_chunks)
90138
read_missing_chunks_parsed = parse_bool(read_missing_chunks)
139+
codec_class_map_parsed = parse_codec_class_map(codec_class_map)
140+
codec_pipeline_class_parsed = parse_codec_pipeline_class(codec_pipeline_class)
91141

92142
object.__setattr__(self, "order", order_parsed)
93143
object.__setattr__(self, "write_empty_chunks", write_empty_chunks_parsed)
94144
object.__setattr__(self, "read_missing_chunks", read_missing_chunks_parsed)
145+
object.__setattr__(self, "codec_class_map", codec_class_map_parsed)
146+
object.__setattr__(self, "codec_pipeline_class", codec_pipeline_class_parsed)
95147

96148
@classmethod
97-
def from_dict(cls, data: ArrayConfigParams) -> Self:
149+
def from_dict(cls, data: ArrayConfigRequest) -> Self:
98150
"""
99151
Create an ArrayConfig from a dict. The keys of that dict are a subset of the
100152
attributes of the ArrayConfig class. Any keys missing from that dict will be set to the
101153
the values in the ``array`` namespace of ``zarr.config``.
102154
"""
103-
kwargs_out: ArrayConfigParams = {}
155+
kwargs_out: ArrayConfigRequest = {}
104156
for f in fields(ArrayConfig):
105157
field_name = cast(
106-
"Literal['order', 'write_empty_chunks', 'read_missing_chunks']", f.name
158+
"Literal['order', 'write_empty_chunks', 'read_missing_chunks', 'codec_class_map', 'codec_pipeline_class']",
159+
f.name,
107160
)
108161
if field_name not in data:
109-
kwargs_out[field_name] = zarr_config.get(f"array.{field_name}")
162+
if field_name in ARRAY_CONFIG_PARAMS_KEYS_STATIC:
163+
kwargs_out[field_name] = zarr_config.get(f"array.{field_name}")
164+
elif field_name == "codec_class_map":
165+
kwargs_out["codec_class_map"] = parse_codec_class_map(None)
166+
elif field_name == "codec_pipeline_class":
167+
kwargs_out["codec_pipeline_class"] = parse_codec_pipeline_class(None)
110168
else:
111169
kwargs_out[field_name] = data[field_name]
112170
return cls(**kwargs_out)
@@ -119,10 +177,76 @@ def to_dict(self) -> ArrayConfigParams:
119177
"order": self.order,
120178
"write_empty_chunks": self.write_empty_chunks,
121179
"read_missing_chunks": self.read_missing_chunks,
180+
"codec_class_map": self.codec_class_map,
181+
"codec_pipeline_class": self.codec_pipeline_class,
122182
}
123183

124184

125-
ArrayConfigLike = ArrayConfig | ArrayConfigParams
185+
ArrayConfigLike = ArrayConfig | ArrayConfigRequest
186+
187+
188+
def _import_by_name(path: str) -> object | type:
189+
"""
190+
Import an object by its fully qualified name.
191+
"""
192+
import importlib
193+
194+
parts = path.split(".")
195+
196+
# Try progressively shorter module paths
197+
for i in range(len(parts), 0, -1):
198+
module_path = ".".join(parts[:i])
199+
try:
200+
module = importlib.import_module(module_path)
201+
break
202+
except ModuleNotFoundError:
203+
continue
204+
else:
205+
raise ImportError(f"Could not import any module from '{path}'")
206+
207+
obj = module
208+
for attr in parts[i:]:
209+
try:
210+
obj = getattr(obj, attr)
211+
except AttributeError as e:
212+
raise ImportError(f"Attribute '{attr}' not found in '{obj}'") from e
213+
return obj
214+
215+
216+
def parse_codec_pipeline_class(obj: CodecPipelineRequest | None) -> CodecPipelineRequest:
217+
if obj is None:
218+
config_entry: dict[str, str | int] = zarr_config.get("codec_pipeline")
219+
if "path" not in config_entry:
220+
msg = (
221+
"The codec_pipeline field in the global config is malformed. "
222+
"Expected 'path' key was not found."
223+
)
224+
raise KeyError(msg)
225+
else:
226+
path = config_entry["path"]
227+
options = {"batch_size": config_entry.get("batch_size", 1)}
228+
return {"class_path": path, "options": options}
229+
return obj
230+
231+
232+
def parse_codec_class_map(obj: Mapping[str, type[Codec]] | None) -> Mapping[str, type[Codec]]:
233+
"""
234+
Convert a request for a codec class map into an actual Mapping[str, type[Codec]].
235+
If the input is `None`, then we look up the list of codecs from the registry, where they
236+
are stored as fully qualified class names. We must resolve these names to concrete classes
237+
before inserting them into the returned mapping.
238+
"""
239+
if obj is None:
240+
name_map: dict[str, str] = zarr_config.get("codecs", {})
241+
out: dict[str, type[Codec]] = {}
242+
for key, value in name_map.items():
243+
maybe_cls = _import_by_name(value)
244+
if not issubclass(maybe_cls, Codec):
245+
msg = f"Expected a subclass of `Codec`, got {maybe_cls}"
246+
raise TypeError(msg)
247+
out[key] = maybe_cls
248+
return out
249+
return obj
126250

127251

128252
def parse_array_config(data: ArrayConfigLike | None) -> ArrayConfig:
@@ -137,25 +261,32 @@ def parse_array_config(data: ArrayConfigLike | None) -> ArrayConfig:
137261
return ArrayConfig.from_dict(data)
138262

139263

264+
@dataclass(frozen=True)
265+
class ArraySpecConfig:
266+
order: MemoryOrder
267+
write_empty_chunks: bool
268+
read_missing_chunks: bool = False
269+
270+
140271
@dataclass(frozen=True)
141272
class ArraySpec:
142273
shape: tuple[int, ...]
143274
dtype: ZDType[TBaseDType, TBaseScalar]
144275
fill_value: Any
145-
config: ArrayConfig
276+
config: ArraySpecConfig
146277
prototype: BufferPrototype
147278

148279
def __init__(
149280
self,
150281
shape: tuple[int, ...],
151282
dtype: ZDType[TBaseDType, TBaseScalar],
152283
fill_value: Any,
153-
config: ArrayConfig,
284+
config: ArraySpecConfig,
154285
prototype: BufferPrototype,
155286
) -> None:
156287
shape_parsed = parse_shapelike(shape)
157288
fill_value_parsed = parse_fill_value(fill_value)
158-
289+
assert isinstance(config, ArraySpecConfig)
159290
object.__setattr__(self, "shape", shape_parsed)
160291
object.__setattr__(self, "dtype", dtype)
161292
object.__setattr__(self, "fill_value", fill_value_parsed)

0 commit comments

Comments
 (0)