Skip to content

Commit 94c9a85

Browse files
committed
fix: can now declare config in open_array
1 parent 9269120 commit 94c9a85

7 files changed

Lines changed: 115 additions & 51 deletions

File tree

src/zarr/api/asynchronous.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1242,6 +1242,7 @@ async def open_array(
12421242
zarr_format: ZarrFormat | None = None,
12431243
path: PathLike = "",
12441244
storage_options: dict[str, Any] | None = None,
1245+
config: ArrayConfigLike | None = None,
12451246
**kwargs: Any, # TODO: type kwargs as valid args to save
12461247
) -> AnyAsyncArray:
12471248
"""Open an array using file-mode-like semantics.
@@ -1261,6 +1262,8 @@ async def open_array(
12611262
storage_options : dict
12621263
If using an fsspec URL to create the store, these will be passed to
12631264
the backend implementation. Ignored otherwise.
1265+
config : ArrayConfigLike
1266+
Declaration of the runtime configuration for the array.
12641267
**kwargs
12651268
Any keyword arguments to pass to [`create`][zarr.api.asynchronous.create].
12661269
@@ -1279,7 +1282,7 @@ async def open_array(
12791282
_warn_write_empty_chunks_kwarg()
12801283

12811284
try:
1282-
return await AsyncArray.open(store_path, zarr_format=zarr_format)
1285+
return await AsyncArray.open(store_path, zarr_format=zarr_format, config=config)
12831286
except FileNotFoundError as err:
12841287
if not store_path.read_only and mode in _CREATE_MODES:
12851288
overwrite = _infer_overwrite(mode)

src/zarr/api/synchronous.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,6 +1369,7 @@ def open_array(
13691369
zarr_format: ZarrFormat | None = None,
13701370
path: PathLike = "",
13711371
storage_options: dict[str, Any] | None = None,
1372+
config: ArrayConfigLike | None = None,
13721373
**kwargs: Any,
13731374
) -> AnyArray:
13741375
"""Open an array using file-mode-like semantics.
@@ -1388,6 +1389,8 @@ def open_array(
13881389
storage_options : dict
13891390
If using an fsspec URL to create the store, these will be passed to
13901391
the backend implementation. Ignored otherwise.
1392+
config : ArrayConfigLike
1393+
Declaration of the runtime configuration for the array.
13911394
**kwargs
13921395
Any keyword arguments to pass to [`create`][zarr.api.asynchronous.create].
13931396
@@ -1405,6 +1408,7 @@ def open_array(
14051408
zarr_format=zarr_format,
14061409
path=path,
14071410
storage_options=storage_options,
1411+
config=config,
14081412
**kwargs,
14091413
)
14101414
)

src/zarr/codecs/sharding.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,16 @@ def __getstate__(self) -> dict[str, Any]:
348348

349349
def __setstate__(self, state: dict[str, Any]) -> None:
350350
config = state["configuration"]
351+
codec_class_map = parse_codec_class_map(None)
351352
object.__setattr__(self, "chunk_shape", parse_shapelike(config["chunk_shape"]))
352-
object.__setattr__(self, "codecs", parse_codecs(config["codecs"]))
353-
object.__setattr__(self, "index_codecs", parse_codecs(config["index_codecs"]))
353+
object.__setattr__(
354+
self, "codecs", parse_codecs(config["codecs"], codec_class_map=codec_class_map)
355+
)
356+
object.__setattr__(
357+
self,
358+
"index_codecs",
359+
parse_codecs(config["index_codecs"], codec_class_map=codec_class_map),
360+
)
354361
object.__setattr__(self, "index_location", parse_index_location(config["index_location"]))
355362

356363
# Use instance-local lru_cache to avoid memory leaks

src/zarr/core/array.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,13 @@ def _chunk_sizes_from_shape(
203203
return tuple(result)
204204

205205

206-
def parse_array_metadata(data: object, config: ArrayConfig) -> ArrayMetadata:
207-
if isinstance(data, ArrayMetadata):
208-
return data.with_config(config)
206+
def parse_array_metadata(data: object, codec_class_map: Mapping[str, type[Codec]]) -> ArrayMetadata:
207+
if isinstance(data, ArrayV3Metadata):
208+
return type(data).from_dict(data.to_dict(), codec_class_map=codec_class_map)
209209
elif isinstance(data, dict):
210210
zarr_format = data.get("zarr_format")
211211
if zarr_format == 3:
212-
meta_out = ArrayV3Metadata.from_dict(data, config=config)
212+
meta_out = ArrayV3Metadata.from_dict(data, codec_class_map=codec_class_map)
213213
if len(meta_out.storage_transformers) > 0:
214214
msg = (
215215
f"Array metadata contains storage transformers: {meta_out.storage_transformers}."
@@ -218,26 +218,37 @@ def parse_array_metadata(data: object, config: ArrayConfig) -> ArrayMetadata:
218218
raise ValueError(msg)
219219
return meta_out
220220
elif zarr_format == 2:
221-
return ArrayV2Metadata.from_dict(data, config=config)
221+
return ArrayV2Metadata.from_dict(data)
222222
else:
223223
raise ValueError(f"Invalid zarr_format: {zarr_format}. Expected 2 or 3")
224224
raise TypeError # pragma: no cover
225225

226226

227-
def create_codec_pipeline(metadata: ArrayMetadata, *, store: Store | None = None) -> CodecPipeline:
227+
def create_codec_pipeline(
228+
metadata: ArrayMetadata,
229+
*,
230+
store: Store | None = None,
231+
config: ArrayConfig | None = None,
232+
) -> CodecPipeline:
233+
pipeline_class: type[CodecPipeline]
234+
if config is not None:
235+
pipeline_class = config.codec_pipeline_class
236+
else:
237+
pipeline_class = get_pipeline_class()
238+
228239
if store is not None:
229240
try:
230-
return get_pipeline_class().from_array_metadata_and_store(
241+
return pipeline_class.from_array_metadata_and_store(
231242
array_metadata=metadata, store=store
232243
)
233244
except NotImplementedError:
234245
pass
235246

236247
if isinstance(metadata, ArrayV3Metadata):
237-
return get_pipeline_class().from_codecs(metadata.codecs)
248+
return pipeline_class.from_codecs(metadata.codecs)
238249
elif isinstance(metadata, ArrayV2Metadata):
239250
v2_codec = V2Codec(filters=metadata.filters, compressor=metadata.compressor)
240-
return get_pipeline_class().from_codecs([v2_codec])
251+
return pipeline_class.from_codecs([v2_codec])
241252
raise TypeError # pragma: no cover
242253

243254

@@ -360,7 +371,9 @@ def __init__(
360371
config: ArrayConfigLike | None = None,
361372
) -> None:
362373
config_parsed = parse_array_config(config)
363-
metadata_parsed = parse_array_metadata(metadata, config=config_parsed)
374+
metadata_parsed = parse_array_metadata(
375+
metadata, codec_class_map=config_parsed.codec_class_map
376+
)
364377

365378
object.__setattr__(self, "metadata", metadata_parsed)
366379
object.__setattr__(self, "store_path", store_path)
@@ -369,7 +382,9 @@ def __init__(
369382
object.__setattr__(
370383
self,
371384
"codec_pipeline",
372-
create_codec_pipeline(metadata=metadata_parsed, store=store_path.store),
385+
create_codec_pipeline(
386+
metadata=metadata_parsed, store=store_path.store, config=config_parsed
387+
),
373388
)
374389

375390
# this overload defines the function signature when zarr_format is 2
@@ -785,6 +800,7 @@ def _create_metadata_v3(
785800
codecs: Iterable[Codec | dict[str, JSON]] | None = None,
786801
dimension_names: DimensionNamesLike = None,
787802
attributes: dict[str, JSON] | None = None,
803+
codec_class_map: Mapping[str, type[Codec]] | None = None,
788804
) -> ArrayV3Metadata:
789805
"""Create an instance of ArrayV3Metadata."""
790806
filters: tuple[ArrayArrayCodec, ...]
@@ -822,6 +838,7 @@ def _create_metadata_v3(
822838
codecs=codecs_parsed, # type: ignore[arg-type]
823839
dimension_names=tuple(dimension_names) if dimension_names else None,
824840
attributes=attributes or {},
841+
codec_class_map=codec_class_map,
825842
)
826843

827844
@classmethod
@@ -869,6 +886,7 @@ async def _create_v3(
869886
codecs=codecs,
870887
dimension_names=dimension_names,
871888
attributes=attributes,
889+
codec_class_map=config.codec_class_map,
872890
)
873891

874892
array = cls(metadata=metadata, store_path=store_path, config=config)
@@ -993,14 +1011,18 @@ def from_dict(
9931011
ValueError
9941012
If the dictionary data is invalid or incompatible with either Zarr format 2 or 3 array creation.
9951013
"""
996-
metadata = parse_array_metadata(data)
1014+
from zarr.core.array_spec import parse_codec_class_map
1015+
1016+
metadata = parse_array_metadata(data, codec_class_map=parse_codec_class_map(None))
9971017
return cls(metadata=metadata, store_path=store_path)
9981018

9991019
@classmethod
10001020
async def open(
10011021
cls,
10021022
store: StoreLike,
10031023
zarr_format: ZarrFormat | None = 3,
1024+
*,
1025+
config: ArrayConfigLike | None = None,
10041026
) -> AnyAsyncArray:
10051027
"""
10061028
Async method to open an existing Zarr array from a given store.
@@ -1013,6 +1035,8 @@ async def open(
10131035
for a description of all valid StoreLike values.
10141036
zarr_format : ZarrFormat | None, optional
10151037
The Zarr format version (default is 3).
1038+
config : ArrayConfigLike | None, optional
1039+
Runtime configuration for the array (default is None).
10161040
10171041
Returns
10181042
-------
@@ -1044,7 +1068,7 @@ async def example():
10441068
metadata_dict = await get_array_metadata(store_path, zarr_format=zarr_format)
10451069
# TODO: remove this cast when we have better type hints
10461070
_metadata_dict = cast("ArrayMetadataJSON_V3", metadata_dict)
1047-
return cls(store_path=store_path, metadata=_metadata_dict)
1071+
return cls(store_path=store_path, metadata=_metadata_dict, config=config)
10481072

10491073
@property
10501074
def store(self) -> Store:
@@ -4710,7 +4734,7 @@ async def init_array(
47104734
chunk_key_encoding: ChunkKeyEncodingLike | None = None,
47114735
dimension_names: DimensionNamesLike = None,
47124736
overwrite: bool = False,
4713-
config: ArrayConfigLike | None = None,
4737+
config: ArrayConfig | None = None,
47144738
) -> AnyAsyncArray:
47154739
"""Create and persist an array metadata document.
47164740
@@ -4948,6 +4972,7 @@ async def init_array(
49484972
codecs=codecs_out,
49494973
dimension_names=dimension_names,
49504974
attributes=attributes,
4975+
codec_class_map=config.codec_class_map if config is not None else None,
49514976
)
49524977

49534978
arr = AsyncArray(metadata=meta, store_path=store_path, config=config)
@@ -5145,7 +5170,7 @@ async def create_array(
51455170
chunk_key_encoding=chunk_key_encoding,
51465171
dimension_names=dimension_names,
51475172
overwrite=overwrite,
5148-
config=config,
5173+
config=parse_array_config(config),
51495174
)
51505175

51515176

src/zarr/core/array_spec.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass, fields
44
from typing import TYPE_CHECKING, Any, Final, Literal, Self, TypedDict, cast
55

6-
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
6+
from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec, CodecPipeline
77
from zarr.core.common import (
88
MemoryOrder,
99
parse_bool,
@@ -30,19 +30,19 @@ class CodecPipelineRequest(TypedDict):
3030
options: NotRequired[dict[str, object]]
3131

3232

33-
class ArrayConfigParams(TypedDict):
33+
class ArrayConfigParams(TypedDict, closed=True): # type: ignore[call-arg]
3434
"""
3535
A TypedDict model of the attributes of an ArrayConfig class.
3636
"""
3737

3838
order: MemoryOrder
3939
write_empty_chunks: bool
4040
read_missing_chunks: bool
41-
codec_class_map: Mapping[str, object]
42-
codec_pipeline_class: CodecPipelineRequest
41+
codec_class_map: Mapping[str, type[ArrayArrayCodec | ArrayBytesCodec | BytesBytesCodec]]
42+
codec_pipeline_class: type[CodecPipeline]
4343

4444

45-
class ArrayConfigRequest(TypedDict):
45+
class ArrayConfigRequest(TypedDict, closed=True): # type: ignore[call-arg]
4646
"""
4747
A TypedDict model of the attributes of an ArrayConfig class, but with no required fields.
4848
This allows for partial construction of an ArrayConfig, with the assumption that the unset
@@ -55,7 +55,7 @@ class ArrayConfigRequest(TypedDict):
5555
codec_class_map: NotRequired[
5656
Mapping[str, type[ArrayArrayCodec | ArrayBytesCodec | BytesBytesCodec]]
5757
]
58-
codec_pipeline_class: NotRequired[CodecPipelineRequest]
58+
codec_pipeline_class: NotRequired[type[CodecPipeline]]
5959

6060

6161
ArrayConfigKeys = Literal[
@@ -112,16 +112,16 @@ class ArrayConfig:
112112
codec_class_map : Mapping[str, type[Codec]]
113113
A codec name : codec class mapping that defines the codec classes available
114114
for array creation.
115-
codec_pipeline_class : CodecPipelineRequest
116-
A request for a pipeline class that will be used for orchestrating chunk encoding and
115+
codec_pipeline_class : type[CodecPipeline]
116+
A codec pipeline class that will be used for orchestrating chunk encoding and
117117
decoding.
118118
"""
119119

120120
order: MemoryOrder
121121
write_empty_chunks: bool
122122
read_missing_chunks: bool
123123
codec_class_map: Mapping[str, type[Codec]]
124-
codec_pipeline_class: CodecPipelineRequest
124+
codec_pipeline_class: type[CodecPipeline]
125125

126126
def __init__(
127127
self,
@@ -131,7 +131,7 @@ def __init__(
131131
read_missing_chunks: bool = True,
132132
codec_class_map: Mapping[str, type[ArrayBytesCodec | ArrayArrayCodec | BytesBytesCodec]]
133133
| None = None,
134-
codec_pipeline_class: CodecPipelineRequest | None = None,
134+
codec_pipeline_class: type[CodecPipeline] | None = None,
135135
) -> None:
136136
order_parsed = parse_order(order)
137137
write_empty_chunks_parsed = parse_bool(write_empty_chunks)
@@ -213,19 +213,16 @@ def _import_by_name(path: str) -> object | type:
213213
return obj
214214

215215

216-
def parse_codec_pipeline_class(obj: CodecPipelineRequest | None) -> CodecPipelineRequest:
216+
def parse_codec_pipeline_class(obj: type[CodecPipeline] | None) -> type[CodecPipeline]:
217217
if obj is None:
218-
config_entry: dict[str, str | int] = zarr_config.get("codec_pipeline")
218+
config_entry: dict[str, str] = zarr_config.get("codec_pipeline")
219219
if "path" not in config_entry:
220220
msg = (
221221
"The codec_pipeline field in the global config is malformed. "
222222
"Expected 'path' key was not found."
223223
)
224224
raise KeyError(msg)
225-
else:
226-
path = config_entry["path"]
227-
options = {"batch_size": config_entry.get("batch_size", 1)}
228-
return {"class_path": path, "options": options}
225+
return _import_by_name(config_entry["path"]) # type: ignore[return-value]
229226
return obj
230227

231228

@@ -241,6 +238,9 @@ def parse_codec_class_map(obj: Mapping[str, type[Codec]] | None) -> Mapping[str,
241238
out: dict[str, type[Codec]] = {}
242239
for key, value in name_map.items():
243240
maybe_cls = _import_by_name(value)
241+
if not isinstance(maybe_cls, type):
242+
msg = f"Expected a type, got {maybe_cls}"
243+
raise TypeError(msg)
244244
if not issubclass(maybe_cls, Codec):
245245
msg = f"Expected a subclass of `Codec`, got {maybe_cls}"
246246
raise TypeError(msg)

0 commit comments

Comments
 (0)