Skip to content

Commit 7f758b3

Browse files
committed
define a module for chunk grids, and a registry
1 parent d926e43 commit 7f758b3

5 files changed

Lines changed: 320 additions & 65 deletions

File tree

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from zarr.core.chunk_grids.common import (
6+
ChunkGrid,
7+
_auto_partition,
8+
_guess_chunks,
9+
_guess_num_chunks_per_axis_shard,
10+
normalize_chunks,
11+
)
12+
from zarr.core.chunk_grids.regular import RegularChunkGrid
13+
from zarr.core.common import JSON, NamedConfig, parse_named_configuration
14+
from zarr.registry import get_chunk_grid_class, register_chunk_grid
15+
16+
register_chunk_grid("regular", RegularChunkGrid)
17+
18+
19+
def parse_chunk_grid(
20+
data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any],
21+
) -> ChunkGrid:
22+
"""Parse a chunk grid from a dictionary, returning existing ChunkGrid instances as-is.
23+
24+
Uses the chunk grid registry to look up the appropriate class by name.
25+
26+
Parameters
27+
----------
28+
data : dict[str, JSON] | ChunkGrid | NamedConfig[str, Any]
29+
Either a ChunkGrid instance (returned as-is) or a dictionary with
30+
'name' and 'configuration' keys.
31+
32+
Returns
33+
-------
34+
ChunkGrid
35+
36+
Raises
37+
------
38+
ValueError
39+
If the chunk grid name is not found in the registry.
40+
"""
41+
if isinstance(data, ChunkGrid):
42+
return data
43+
44+
name_parsed, _ = parse_named_configuration(data)
45+
try:
46+
chunk_grid_cls = get_chunk_grid_class(name_parsed)
47+
except KeyError as e:
48+
raise ValueError(f"Unknown chunk grid. Got {name_parsed}.") from e
49+
return chunk_grid_cls.from_dict(data) # type: ignore[arg-type]
50+
51+
52+
__all__ = [
53+
"ChunkGrid",
54+
"RegularChunkGrid",
55+
"_auto_partition",
56+
"_guess_chunks",
57+
"_guess_num_chunks_per_axis_shard",
58+
"normalize_chunks",
59+
"parse_chunk_grid",
60+
]
Lines changed: 136 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,159 @@
11
from __future__ import annotations
22

3-
import itertools
43
import math
54
import numbers
6-
import operator
75
import warnings
86
from abc import abstractmethod
97
from dataclasses import dataclass
10-
from functools import reduce
118
from typing import TYPE_CHECKING, Any, Literal
129

1310
import numpy as np
11+
import numpy.typing as npt
1412

1513
import zarr
1614
from zarr.abc.metadata import Metadata
17-
from zarr.core.common import (
18-
JSON,
19-
NamedConfig,
20-
ShapeLike,
21-
ceildiv,
22-
parse_named_configuration,
23-
parse_shapelike,
24-
)
2515
from zarr.errors import ZarrUserWarning
2616

2717
if TYPE_CHECKING:
2818
from collections.abc import Iterator
2919
from typing import Self
3020

3121
from zarr.core.array import ShardsLike
22+
from zarr.core.common import JSON
23+
24+
25+
@dataclass(frozen=True)
26+
class ChunkGrid(Metadata):
27+
@abstractmethod
28+
def to_dict(self) -> dict[str, JSON]: ...
29+
30+
@abstractmethod
31+
def update_shape(self, new_shape: tuple[int, ...]) -> Self:
32+
pass
33+
34+
@abstractmethod
35+
def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
36+
pass
37+
38+
@abstractmethod
39+
def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
40+
pass
41+
42+
@abstractmethod
43+
def get_chunk_shape(
44+
self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...]
45+
) -> tuple[int, ...]:
46+
"""
47+
Get the shape of a specific chunk.
48+
49+
Parameters
50+
----------
51+
array_shape : tuple[int, ...]
52+
Shape of the full array.
53+
chunk_coord : tuple[int, ...]
54+
Coordinates of the chunk in the chunk grid.
55+
56+
Returns
57+
-------
58+
tuple[int, ...]
59+
Shape of the chunk at the given coordinates.
60+
"""
61+
62+
@abstractmethod
63+
def get_chunk_start(
64+
self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...]
65+
) -> tuple[int, ...]:
66+
"""
67+
Get the starting position of a chunk in the array.
68+
69+
Parameters
70+
----------
71+
array_shape : tuple[int, ...]
72+
Shape of the full array.
73+
chunk_coord : tuple[int, ...]
74+
Coordinates of the chunk in the chunk grid.
75+
76+
Returns
77+
-------
78+
tuple[int, ...]
79+
Starting position (offset) of the chunk in the array.
80+
"""
81+
82+
@abstractmethod
83+
def array_index_to_chunk_coord(
84+
self, array_shape: tuple[int, ...], array_index: tuple[int, ...]
85+
) -> tuple[int, ...]:
86+
"""
87+
Map an array index to the chunk coordinates that contain it.
88+
89+
Parameters
90+
----------
91+
array_shape : tuple[int, ...]
92+
Shape of the full array.
93+
array_index : tuple[int, ...]
94+
Index in the array.
95+
96+
Returns
97+
-------
98+
tuple[int, ...]
99+
Coordinates of the chunk containing the array index.
100+
"""
101+
102+
@abstractmethod
103+
def array_indices_to_chunk_dim(
104+
self, array_shape: tuple[int, ...], dim: int, indices: npt.NDArray[np.intp]
105+
) -> npt.NDArray[np.intp]:
106+
"""
107+
Map an array of indices along one dimension to chunk coordinates (vectorized).
108+
109+
Parameters
110+
----------
111+
array_shape : tuple[int, ...]
112+
Shape of the full array.
113+
dim : int
114+
Dimension index.
115+
indices : np.ndarray
116+
Array of indices along the given dimension.
117+
118+
Returns
119+
-------
120+
np.ndarray
121+
Array of chunk coordinates, same shape as indices.
122+
"""
123+
124+
@abstractmethod
125+
def chunks_per_dim(self, array_shape: tuple[int, ...], dim: int) -> int:
126+
"""
127+
Get the number of chunks along a specific dimension.
128+
129+
Parameters
130+
----------
131+
array_shape : tuple[int, ...]
132+
Shape of the full array.
133+
dim : int
134+
Dimension index.
135+
136+
Returns
137+
-------
138+
int
139+
Number of chunks along the dimension.
140+
"""
141+
142+
@abstractmethod
143+
def get_chunk_grid_shape(self, array_shape: tuple[int, ...]) -> tuple[int, ...]:
144+
"""
145+
Get the shape of the chunk grid (number of chunks along each dimension).
146+
147+
Parameters
148+
----------
149+
array_shape : tuple[int, ...]
150+
Shape of the full array.
151+
152+
Returns
153+
-------
154+
tuple[int, ...]
155+
Number of chunks along each dimension.
156+
"""
32157

33158

34159
def _guess_chunks(
@@ -153,58 +278,6 @@ def normalize_chunks(chunks: Any, shape: tuple[int, ...], typesize: int) -> tupl
153278
return tuple(int(c) for c in chunks)
154279

155280

156-
@dataclass(frozen=True)
157-
class ChunkGrid(Metadata):
158-
@classmethod
159-
def from_dict(cls, data: dict[str, JSON] | ChunkGrid | NamedConfig[str, Any]) -> ChunkGrid:
160-
if isinstance(data, ChunkGrid):
161-
return data
162-
163-
name_parsed, _ = parse_named_configuration(data)
164-
if name_parsed == "regular":
165-
return RegularChunkGrid._from_dict(data)
166-
raise ValueError(f"Unknown chunk grid. Got {name_parsed}.")
167-
168-
@abstractmethod
169-
def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
170-
pass
171-
172-
@abstractmethod
173-
def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
174-
pass
175-
176-
177-
@dataclass(frozen=True)
178-
class RegularChunkGrid(ChunkGrid):
179-
chunk_shape: tuple[int, ...]
180-
181-
def __init__(self, *, chunk_shape: ShapeLike) -> None:
182-
chunk_shape_parsed = parse_shapelike(chunk_shape)
183-
184-
object.__setattr__(self, "chunk_shape", chunk_shape_parsed)
185-
186-
@classmethod
187-
def _from_dict(cls, data: dict[str, JSON] | NamedConfig[str, Any]) -> Self:
188-
_, configuration_parsed = parse_named_configuration(data, "regular")
189-
190-
return cls(**configuration_parsed) # type: ignore[arg-type]
191-
192-
def to_dict(self) -> dict[str, JSON]:
193-
return {"name": "regular", "configuration": {"chunk_shape": tuple(self.chunk_shape)}}
194-
195-
def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
196-
return itertools.product(
197-
*(range(ceildiv(s, c)) for s, c in zip(array_shape, self.chunk_shape, strict=False))
198-
)
199-
200-
def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
201-
return reduce(
202-
operator.mul,
203-
itertools.starmap(ceildiv, zip(array_shape, self.chunk_shape, strict=True)),
204-
1,
205-
)
206-
207-
208281
def _guess_num_chunks_per_axis_shard(
209282
chunk_shape: tuple[int, ...], item_size: int, max_bytes: int, array_shape: tuple[int, ...]
210283
) -> int:
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from __future__ import annotations
2+
3+
import itertools
4+
import operator
5+
from dataclasses import dataclass
6+
from functools import reduce
7+
from typing import TYPE_CHECKING, Any
8+
9+
import numpy as np
10+
import numpy.typing as npt
11+
12+
from zarr.core.chunk_grids.common import ChunkGrid
13+
from zarr.core.common import (
14+
JSON,
15+
NamedConfig,
16+
ShapeLike,
17+
ceildiv,
18+
parse_named_configuration,
19+
parse_shapelike,
20+
)
21+
22+
if TYPE_CHECKING:
23+
from collections.abc import Iterator
24+
from typing import Self
25+
26+
27+
@dataclass(frozen=True)
28+
class RegularChunkGrid(ChunkGrid):
29+
chunk_shape: tuple[int, ...]
30+
31+
def __init__(self, *, chunk_shape: ShapeLike) -> None:
32+
chunk_shape_parsed = parse_shapelike(chunk_shape)
33+
34+
object.__setattr__(self, "chunk_shape", chunk_shape_parsed)
35+
36+
@classmethod
37+
def from_dict(cls, data: dict[str, JSON] | NamedConfig[str, Any]) -> Self:
38+
_, configuration_parsed = parse_named_configuration(data, "regular")
39+
40+
return cls(**configuration_parsed) # type: ignore[arg-type]
41+
42+
def to_dict(self) -> dict[str, JSON]:
43+
return {"name": "regular", "configuration": {"chunk_shape": tuple(self.chunk_shape)}}
44+
45+
def update_shape(self, new_shape: tuple[int, ...]) -> Self:
46+
return self
47+
48+
def all_chunk_coords(self, array_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
49+
return itertools.product(
50+
*(range(ceildiv(s, c)) for s, c in zip(array_shape, self.chunk_shape, strict=False))
51+
)
52+
53+
def get_nchunks(self, array_shape: tuple[int, ...]) -> int:
54+
return reduce(
55+
operator.mul,
56+
itertools.starmap(ceildiv, zip(array_shape, self.chunk_shape, strict=True)),
57+
1,
58+
)
59+
60+
def get_chunk_shape(
61+
self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...]
62+
) -> tuple[int, ...]:
63+
return tuple(
64+
int(min(self.chunk_shape[i], array_shape[i] - chunk_coord[i] * self.chunk_shape[i]))
65+
for i in range(len(array_shape))
66+
)
67+
68+
def get_chunk_start(
69+
self, array_shape: tuple[int, ...], chunk_coord: tuple[int, ...]
70+
) -> tuple[int, ...]:
71+
return tuple(
72+
coord * size for coord, size in zip(chunk_coord, self.chunk_shape, strict=False)
73+
)
74+
75+
def array_index_to_chunk_coord(
76+
self, array_shape: tuple[int, ...], array_index: tuple[int, ...]
77+
) -> tuple[int, ...]:
78+
return tuple(
79+
0 if size == 0 else idx // size
80+
for idx, size in zip(array_index, self.chunk_shape, strict=False)
81+
)
82+
83+
def array_indices_to_chunk_dim(
84+
self, array_shape: tuple[int, ...], dim: int, indices: npt.NDArray[np.intp]
85+
) -> npt.NDArray[np.intp]:
86+
chunk_size = self.chunk_shape[dim]
87+
if chunk_size == 0:
88+
return np.zeros_like(indices)
89+
return indices // chunk_size
90+
91+
def chunks_per_dim(self, array_shape: tuple[int, ...], dim: int) -> int:
92+
return ceildiv(array_shape[dim], self.chunk_shape[dim])
93+
94+
def get_chunk_grid_shape(self, array_shape: tuple[int, ...]) -> tuple[int, ...]:
95+
return tuple(
96+
ceildiv(array_len, chunk_len)
97+
for array_len, chunk_len in zip(array_shape, self.chunk_shape, strict=False)
98+
)

0 commit comments

Comments
 (0)