Skip to content

Commit faa5ef5

Browse files
d-v-bclaude
andcommitted
feat: add IndexTransform library for composable, lazy coordinate mappings
Add a new `src/zarr/core/transforms/` package implementing TensorStore-inspired index transforms. The core idea: every indexing operation (slicing, fancy indexing, etc.) produces a coordinate mapping from user space to storage space. These mappings compose lazily — no I/O until explicitly resolved. Key types: - `IndexDomain` — rectangular region in N-dimensional integer space - `ConstantMap`, `DimensionMap`, `ArrayMap` — three representations of a set of storage coordinates (singleton, arithmetic progression, explicit enumeration) - `IndexTransform` — pairs an input domain with output maps (one per storage dim) - `compose(outer, inner)` — chain two transforms Key operations on IndexTransform: - `__getitem__`, `.oindex[]`, `.vindex[]` — indexing produces new transforms - `.intersect(domain)` — restrict to coordinates within a region (chunk resolution) - `.translate(shift)` — shift coordinates (make chunk-local) The transform library is standalone with no dependency on Array. Includes comprehensive test suite (143 tests covering all types, operations, composition, chunk resolution, and edge cases). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e6207b7 commit faa5ef5

15 files changed

Lines changed: 3369 additions & 60 deletions

src/zarr/core/array.py

Lines changed: 542 additions & 59 deletions
Large diffs are not rendered by default.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Composable, lazy coordinate transforms for zarr array indexing.
2+
3+
This package implements TensorStore-inspired index transforms. The core idea:
4+
every indexing operation (slicing, fancy indexing, etc.) produces a coordinate
5+
mapping from user space to storage space. These mappings compose lazily — no
6+
I/O until you explicitly read or write.
7+
8+
Key types:
9+
10+
- ``IndexDomain`` — a rectangular region of integer coordinates
11+
- ``IndexTransform`` — maps input coordinates to storage coordinates
12+
- ``ConstantMap``, ``DimensionMap``, ``ArrayMap`` — the three ways a single
13+
output dimension can depend on the input (see ``output_map.py``)
14+
- ``compose`` — chain two transforms into one
15+
"""
16+
17+
from zarr.core.transforms.composition import compose
18+
from zarr.core.transforms.domain import IndexDomain
19+
from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap
20+
from zarr.core.transforms.transform import IndexTransform
21+
22+
__all__ = [
23+
"ArrayMap",
24+
"ConstantMap",
25+
"DimensionMap",
26+
"IndexDomain",
27+
"IndexTransform",
28+
"OutputIndexMap",
29+
"compose",
30+
]
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
"""Chunk resolution — mapping transforms to chunk-level I/O.
2+
3+
Given an ``IndexTransform`` (which coordinates a user wants to access) and a
4+
``ChunkGrid`` (how storage is divided into chunks), chunk resolution answers:
5+
6+
For each chunk, which storage coordinates does this transform touch,
7+
and where do those values land in the output buffer?
8+
9+
The algorithm is:
10+
11+
1. **Enumerate candidate chunks** — determine which chunks could possibly
12+
be touched by the transform's output coordinate ranges.
13+
14+
2. **Intersect** — for each candidate chunk, call
15+
``transform.intersect(chunk_domain)`` to restrict the transform to
16+
coordinates within that chunk. If the intersection is empty, skip it.
17+
18+
3. **Translate** — shift the restricted transform to chunk-local coordinates
19+
via ``transform.translate(-chunk_origin)``.
20+
21+
4. **Yield** — produce ``(chunk_coords, local_transform, surviving_indices)``
22+
triples that the codec pipeline consumes.
23+
24+
``sub_transform_to_selections`` bridges from the transform representation
25+
back to the raw ``(chunk_selection, out_selection, drop_axes)`` tuples that
26+
the current codec pipeline expects. This bridge will go away when the codec
27+
pipeline accepts transforms natively.
28+
"""
29+
30+
from __future__ import annotations
31+
32+
from typing import TYPE_CHECKING, Any
33+
34+
import numpy as np
35+
36+
from zarr.core.transforms.domain import IndexDomain
37+
from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap
38+
from zarr.core.transforms.transform import IndexTransform
39+
40+
if TYPE_CHECKING:
41+
from collections.abc import Iterator
42+
43+
from zarr.core.chunk_grids import ChunkGrid
44+
45+
ChunkTransformResult = tuple[
46+
tuple[int, ...],
47+
IndexTransform,
48+
np.ndarray[Any, np.dtype[np.intp]] | None,
49+
]
50+
51+
52+
def iter_chunk_transforms(
53+
transform: IndexTransform,
54+
chunk_grid: ChunkGrid,
55+
) -> Iterator[ChunkTransformResult]:
56+
"""Resolve a composed IndexTransform against a ChunkGrid.
57+
58+
Yields ``(chunk_coords, sub_transform, out_indices)`` triples:
59+
60+
- ``chunk_coords``: which chunk to access.
61+
- ``sub_transform``: maps output buffer coords to chunk-local coords.
62+
- ``out_indices``: for vectorized/array indexing, the output scatter
63+
indices (integer array). ``None`` for basic/slice indexing.
64+
"""
65+
dim_grids = chunk_grid._dimensions
66+
67+
# Enumerate all possible chunks via cartesian product of per-dim chunk ranges
68+
# For each candidate chunk, intersect the transform with the chunk domain.
69+
# The transform.intersect method handles both orthogonal and vectorized cases.
70+
chunk_ranges: list[range] = []
71+
for out_dim, m in enumerate(transform.output):
72+
dg = dim_grids[out_dim]
73+
if isinstance(m, ConstantMap):
74+
# Single chunk
75+
c = dg.index_to_chunk(m.offset)
76+
chunk_ranges.append(range(c, c + 1))
77+
elif isinstance(m, DimensionMap):
78+
d = m.input_dimension
79+
dim_lo = transform.domain.inclusive_min[d]
80+
dim_hi = transform.domain.exclusive_max[d]
81+
if dim_lo >= dim_hi:
82+
return # empty domain
83+
if m.stride > 0:
84+
s_min = m.offset + m.stride * dim_lo
85+
s_max = m.offset + m.stride * (dim_hi - 1)
86+
else:
87+
s_min = m.offset + m.stride * (dim_hi - 1)
88+
s_max = m.offset + m.stride * dim_lo
89+
first = dg.index_to_chunk(s_min)
90+
last = dg.index_to_chunk(s_max)
91+
chunk_ranges.append(range(first, last + 1))
92+
elif isinstance(m, ArrayMap):
93+
storage = m.offset + m.stride * m.index_array
94+
flat = storage.ravel().astype(np.intp)
95+
chunk_ids = dg.indices_to_chunks(flat)
96+
first = int(chunk_ids.min())
97+
last = int(chunk_ids.max())
98+
chunk_ranges.append(range(first, last + 1))
99+
100+
import itertools
101+
102+
for chunk_coords_tuple in itertools.product(*chunk_ranges):
103+
chunk_coords = tuple(int(c) for c in chunk_coords_tuple)
104+
105+
# Build the chunk domain in storage space
106+
chunk_min: list[int] = []
107+
chunk_max: list[int] = []
108+
chunk_shift: list[int] = []
109+
for out_dim, c in enumerate(chunk_coords):
110+
dg = dim_grids[out_dim]
111+
c_start = dg.chunk_offset(c)
112+
c_size = dg.chunk_size(c)
113+
chunk_min.append(c_start)
114+
chunk_max.append(c_start + c_size)
115+
chunk_shift.append(-c_start)
116+
117+
chunk_domain = IndexDomain(
118+
inclusive_min=tuple(chunk_min),
119+
exclusive_max=tuple(chunk_max),
120+
)
121+
122+
# Intersect transform with chunk domain
123+
result = transform.intersect(chunk_domain)
124+
if result is None:
125+
continue
126+
127+
restricted, surviving = result
128+
129+
# Translate to chunk-local coordinates
130+
local = restricted.translate(tuple(chunk_shift))
131+
132+
yield (chunk_coords, local, surviving)
133+
134+
135+
def sub_transform_to_selections(
136+
sub_transform: IndexTransform,
137+
out_indices: np.ndarray[Any, np.dtype[np.intp]] | None = None,
138+
) -> tuple[
139+
tuple[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...],
140+
tuple[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...],
141+
tuple[int, ...],
142+
]:
143+
"""Convert a chunk-local sub-transform to raw selections for the codec pipeline.
144+
145+
Parameters
146+
----------
147+
sub_transform
148+
A chunk-local IndexTransform (output maps already translated to
149+
chunk-local coordinates).
150+
out_indices
151+
For vectorized indexing: the output scatter indices for this chunk.
152+
None for orthogonal/basic indexing.
153+
154+
Returns
155+
-------
156+
tuple
157+
``(chunk_selection, out_selection, drop_axes)``
158+
"""
159+
chunk_sel: list[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = []
160+
drop_axes: list[int] = []
161+
162+
for m in sub_transform.output:
163+
if isinstance(m, ConstantMap):
164+
chunk_sel.append(m.offset)
165+
elif isinstance(m, DimensionMap):
166+
dim_lo = sub_transform.domain.inclusive_min[m.input_dimension]
167+
dim_hi = sub_transform.domain.exclusive_max[m.input_dimension]
168+
start = m.offset + m.stride * dim_lo
169+
stop = m.offset + m.stride * dim_hi
170+
if m.stride < 0:
171+
start, stop = stop + 1, start + 1
172+
chunk_sel.append(slice(start, stop, m.stride))
173+
elif isinstance(m, ArrayMap):
174+
if m.offset == 0 and m.stride == 1:
175+
chunk_sel.append(m.index_array)
176+
else:
177+
storage_coords = m.offset + m.stride * m.index_array
178+
chunk_sel.append(storage_coords.astype(np.intp))
179+
180+
# Build out_sel: one entry per non-dropped output dim.
181+
out_sel: list[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = []
182+
183+
# Vectorized: multiple correlated ArrayMaps share one scatter index
184+
is_vectorized = (
185+
out_indices is not None
186+
and sum(1 for m in sub_transform.output if isinstance(m, ArrayMap)) >= 2
187+
)
188+
189+
if is_vectorized:
190+
assert out_indices is not None
191+
out_sel.append(out_indices)
192+
else:
193+
for m in sub_transform.output:
194+
if isinstance(m, ConstantMap):
195+
continue
196+
if isinstance(m, DimensionMap):
197+
lo = sub_transform.domain.inclusive_min[m.input_dimension]
198+
hi = sub_transform.domain.exclusive_max[m.input_dimension]
199+
out_sel.append(slice(lo, hi))
200+
elif isinstance(m, ArrayMap):
201+
if out_indices is not None:
202+
# Orthogonal ArrayMap: out_indices has the surviving positions
203+
out_sel.append(out_indices)
204+
else:
205+
out_sel.append(slice(0, len(m.index_array)))
206+
207+
return tuple(chunk_sel), tuple(out_sel), tuple(drop_axes)
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
5+
from zarr.core.transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap
6+
from zarr.core.transforms.transform import IndexTransform
7+
8+
9+
def compose(outer: IndexTransform, inner: IndexTransform) -> IndexTransform:
10+
"""Compose two IndexTransforms.
11+
12+
``outer`` maps user coords (rank m) to intermediate coords (rank n).
13+
``inner`` maps intermediate coords (rank n) to storage coords (rank p).
14+
The result maps user coords (rank m) to storage coords (rank p).
15+
16+
Precondition: ``outer.output_rank == inner.domain.ndim``.
17+
"""
18+
if outer.output_rank != inner.domain.ndim:
19+
raise ValueError(
20+
f"outer output rank ({outer.output_rank}) must match inner input rank "
21+
f"({inner.domain.ndim})"
22+
)
23+
24+
result_output = [_compose_single(outer, inner_map) for inner_map in inner.output]
25+
26+
return IndexTransform(domain=outer.domain, output=tuple(result_output))
27+
28+
29+
def _compose_single(outer: IndexTransform, inner_map: OutputIndexMap) -> OutputIndexMap:
30+
"""Compose a single inner output map with the full outer transform."""
31+
if isinstance(inner_map, ConstantMap):
32+
return ConstantMap(offset=inner_map.offset)
33+
34+
if isinstance(inner_map, DimensionMap):
35+
return _compose_dimension(outer, inner_map)
36+
37+
if isinstance(inner_map, ArrayMap):
38+
return _compose_array(outer, inner_map)
39+
40+
raise TypeError(f"Unknown output map type: {type(inner_map)}") # pragma: no cover
41+
42+
43+
def _compose_dimension(outer: IndexTransform, inner_map: DimensionMap) -> OutputIndexMap:
44+
"""Compose when inner is a DimensionMap.
45+
46+
storage = offset_i + stride_i * intermediate[dim_i]
47+
where intermediate[dim_i] = outer.output[dim_i](user_input)
48+
"""
49+
dim_i = inner_map.input_dimension
50+
offset_i = inner_map.offset
51+
stride_i = inner_map.stride
52+
outer_map = outer.output[dim_i]
53+
54+
if isinstance(outer_map, ConstantMap):
55+
return ConstantMap(offset=offset_i + stride_i * outer_map.offset)
56+
57+
if isinstance(outer_map, DimensionMap):
58+
return DimensionMap(
59+
input_dimension=outer_map.input_dimension,
60+
offset=offset_i + stride_i * outer_map.offset,
61+
stride=stride_i * outer_map.stride,
62+
)
63+
64+
if isinstance(outer_map, ArrayMap):
65+
return ArrayMap(
66+
index_array=outer_map.index_array,
67+
offset=offset_i + stride_i * outer_map.offset,
68+
stride=stride_i * outer_map.stride,
69+
)
70+
71+
raise TypeError(f"Unknown output map type: {type(outer_map)}") # pragma: no cover
72+
73+
74+
def _compose_array(outer: IndexTransform, inner_map: ArrayMap) -> OutputIndexMap:
75+
"""Compose when inner is an ArrayMap.
76+
77+
storage = offset_i + stride_i * arr_i[intermediate]
78+
We need to evaluate arr_i at the intermediate coordinates produced by outer.
79+
"""
80+
arr_i = inner_map.index_array
81+
offset_i = inner_map.offset
82+
stride_i = inner_map.stride
83+
84+
# Check if all outer outputs are constant
85+
all_constant = all(isinstance(m, ConstantMap) for m in outer.output)
86+
87+
if all_constant:
88+
# Evaluate arr_i at the single constant point
89+
idx = tuple(m.offset for m in outer.output if isinstance(m, ConstantMap))
90+
value = int(arr_i[idx])
91+
return ConstantMap(offset=offset_i + stride_i * value)
92+
93+
# For 1D inner array with a single outer output (simple case)
94+
if arr_i.ndim == 1 and len(outer.output) == 1:
95+
outer_map = outer.output[0]
96+
97+
if isinstance(outer_map, DimensionMap):
98+
dim_size = outer.domain.shape[outer_map.input_dimension]
99+
user_indices = np.arange(dim_size, dtype=np.intp)
100+
intermediate_vals = outer_map.offset + outer_map.stride * user_indices
101+
new_arr = arr_i[intermediate_vals]
102+
return ArrayMap(index_array=new_arr, offset=offset_i, stride=stride_i)
103+
104+
if isinstance(outer_map, ArrayMap):
105+
intermediate_vals = outer_map.offset + outer_map.stride * outer_map.index_array
106+
new_arr = arr_i[intermediate_vals]
107+
return ArrayMap(index_array=new_arr, offset=offset_i, stride=stride_i)
108+
109+
# General multi-dim case: not yet implemented
110+
raise NotImplementedError(
111+
"Composing a multi-dimensional inner array map with non-constant outer maps "
112+
"is not yet supported."
113+
)

0 commit comments

Comments
 (0)