Skip to content

Commit 71a780b

Browse files
committed
chunktransform
1 parent 4e262b1 commit 71a780b

File tree

2 files changed

+75
-38
lines changed

2 files changed

+75
-38
lines changed

src/zarr/core/codec_pipeline.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -69,49 +69,55 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any:
6969
return fill_value
7070

7171

72-
@dataclass(frozen=True, slots=True)
73-
class CodecChain:
74-
"""Codec chain with pre-resolved metadata specs.
72+
@dataclass(slots=True, kw_only=True)
73+
class ChunkTransform:
74+
"""A stored chunk, modeled as a layered array.
7575
76-
Constructed from an iterable of codecs and a chunk ArraySpec.
77-
Resolves each codec against the spec so that encode/decode can
78-
run without re-resolving.
76+
Each layer corresponds to one ArrayArrayCodec and the ArraySpec
77+
at its input boundary. ``layers[0]`` is the outermost (user-visible)
78+
transform; after the last layer comes the ArrayBytesCodec, followed by any BytesBytesCodecs.
79+
80+
The chunk's ``shape`` and ``dtype`` reflect the representation
81+
**after** all ArrayArrayCodec layers have been applied — i.e. the
82+
spec that feeds the ArrayBytesCodec.
7983
"""
8084

8185
codecs: tuple[Codec, ...]
82-
chunk_spec: ArraySpec
86+
array_spec: ArraySpec
8387

84-
_aa_codecs: tuple[ArrayArrayCodec, ...] = field(init=False, repr=False, compare=False)
85-
_aa_specs: tuple[ArraySpec, ...] = field(init=False, repr=False, compare=False)
88+
# Each element is (ArrayArrayCodec, input_spec_for_that_codec).
89+
layers: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = field(
90+
init=False, repr=False, compare=False
91+
)
8692
_ab_codec: ArrayBytesCodec = field(init=False, repr=False, compare=False)
8793
_ab_spec: ArraySpec = field(init=False, repr=False, compare=False)
8894
_bb_codecs: tuple[BytesBytesCodec, ...] = field(init=False, repr=False, compare=False)
89-
_bb_spec: ArraySpec = field(init=False, repr=False, compare=False)
9095
_all_sync: bool = field(init=False, repr=False, compare=False)
9196

9297
def __post_init__(self) -> None:
9398
aa, ab, bb = codecs_from_list(list(self.codecs))
9499

95-
aa_specs: list[ArraySpec] = []
96-
spec = self.chunk_spec
100+
layers: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = ()
101+
spec = self.array_spec
97102
for aa_codec in aa:
98-
aa_specs.append(spec)
103+
layers = (*layers, (aa_codec, spec))
99104
spec = aa_codec.resolve_metadata(spec)
100105

101-
object.__setattr__(self, "_aa_codecs", aa)
102-
object.__setattr__(self, "_aa_specs", tuple(aa_specs))
103-
object.__setattr__(self, "_ab_codec", ab)
104-
object.__setattr__(self, "_ab_spec", spec)
106+
self.layers = layers
107+
self._ab_codec = ab
108+
self._ab_spec = spec
109+
self._bb_codecs = bb
110+
self._all_sync = all(isinstance(c, SupportsSyncCodec) for c in self.codecs)
105111

106-
spec = ab.resolve_metadata(spec)
107-
object.__setattr__(self, "_bb_codecs", bb)
108-
object.__setattr__(self, "_bb_spec", spec)
112+
@property
113+
def shape(self) -> tuple[int, ...]:
114+
"""Shape after all ArrayArrayCodec layers (input to the ArrayBytesCodec)."""
115+
return self._ab_spec.shape
109116

110-
object.__setattr__(
111-
self,
112-
"_all_sync",
113-
all(isinstance(c, SupportsSyncCodec) for c in self.codecs),
114-
)
117+
@property
118+
def dtype(self) -> ZDType[TBaseDType, TBaseScalar]:
119+
"""Dtype after all ArrayArrayCodec layers (input to the ArrayBytesCodec)."""
120+
return self._ab_spec.dtype
115121

116122
@property
117123
def all_sync(self) -> bool:
@@ -127,11 +133,11 @@ def decode_chunk(
127133
"""
128134
bb_out: Any = chunk_bytes
129135
for bb_codec in reversed(self._bb_codecs):
130-
bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self._bb_spec)
136+
bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self._ab_spec)
131137

132138
ab_out: Any = cast("SupportsSyncCodec", self._ab_codec)._decode_sync(bb_out, self._ab_spec)
133139

134-
for aa_codec, spec in zip(reversed(self._aa_codecs), reversed(self._aa_specs), strict=True):
140+
for aa_codec, spec in reversed(self.layers):
135141
ab_out = cast("SupportsSyncCodec", aa_codec)._decode_sync(ab_out, spec)
136142

137143
return ab_out # type: ignore[no-any-return]
@@ -146,7 +152,7 @@ def encode_chunk(
146152
"""
147153
aa_out: Any = chunk_array
148154

149-
for aa_codec, spec in zip(self._aa_codecs, self._aa_specs, strict=True):
155+
for aa_codec, spec in self.layers:
150156
if aa_out is None:
151157
return None
152158
aa_out = cast("SupportsSyncCodec", aa_codec)._encode_sync(aa_out, spec)
@@ -158,7 +164,7 @@ def encode_chunk(
158164
for bb_codec in self._bb_codecs:
159165
if bb_out is None:
160166
return None
161-
bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self._bb_spec)
167+
bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self._ab_spec)
162168

163169
return bb_out # type: ignore[no-any-return]
164170

@@ -369,7 +375,7 @@ async def read_batch(
369375
out[out_selection] = fill_value_or_default(chunk_spec)
370376
else:
371377
chunk_bytes_batch = await concurrent_map(
372-
[(byte_getter, array_spec.prototype) for byte_getter, array_spec, *_ in batch_info],
378+
[(byte_getter, chunk_spec.prototype) for byte_getter, chunk_spec, *_ in batch_info],
373379
lambda byte_getter, prototype: byte_getter.get(prototype),
374380
config.get("async.concurrency"),
375381
)

tests/test_sync_codec_pipeline.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from zarr.codecs.zstd import ZstdCodec
1111
from zarr.core.array_spec import ArrayConfig, ArraySpec
1212
from zarr.core.buffer import NDBuffer, default_buffer_prototype
13-
from zarr.core.codec_pipeline import CodecChain
13+
from zarr.core.codec_pipeline import ChunkTransform
1414
from zarr.core.dtype import get_data_type_from_native_dtype
1515

1616

@@ -29,37 +29,65 @@ def _make_nd_buffer(arr: np.ndarray[Any, np.dtype[Any]]) -> NDBuffer:
2929
return default_buffer_prototype().nd_buffer.from_numpy_array(arr)
3030

3131

32-
class TestCodecChain:
32+
class TestChunkTransform:
3333
def test_all_sync(self) -> None:
3434
spec = _make_array_spec((100,), np.dtype("float64"))
35-
chain = CodecChain((BytesCodec(),), spec)
35+
chain = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec)
3636
assert chain.all_sync is True
3737

3838
def test_all_sync_with_compression(self) -> None:
3939
spec = _make_array_spec((100,), np.dtype("float64"))
40-
chain = CodecChain((BytesCodec(), GzipCodec()), spec)
40+
chain = ChunkTransform(codecs=(BytesCodec(), GzipCodec()), array_spec=spec)
4141
assert chain.all_sync is True
4242

4343
def test_all_sync_full_chain(self) -> None:
4444
spec = _make_array_spec((3, 4), np.dtype("float64"))
45-
chain = CodecChain((TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()), spec)
45+
chain = ChunkTransform(
46+
codecs=(TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()), array_spec=spec
47+
)
4648
assert chain.all_sync is True
4749

4850
def test_encode_decode_roundtrip_bytes_only(self) -> None:
4951
arr = np.arange(100, dtype="float64")
5052
spec = _make_array_spec(arr.shape, arr.dtype)
51-
chain = CodecChain((BytesCodec(),), spec)
53+
chain = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec)
5254
nd_buf = _make_nd_buffer(arr)
5355

5456
encoded = chain.encode_chunk(nd_buf)
5557
assert encoded is not None
5658
decoded = chain.decode_chunk(encoded)
5759
np.testing.assert_array_equal(arr, decoded.as_numpy_array())
5860

61+
def test_layers_no_aa_codecs(self) -> None:
62+
spec = _make_array_spec((100,), np.dtype("float64"))
63+
chunk = ChunkTransform(codecs=(BytesCodec(), GzipCodec()), array_spec=spec)
64+
assert chunk.layers == ()
65+
66+
def test_layers_with_transpose(self) -> None:
67+
spec = _make_array_spec((3, 4), np.dtype("float64"))
68+
transpose = TransposeCodec(order=(1, 0))
69+
chunk = ChunkTransform(codecs=(transpose, BytesCodec(), ZstdCodec()), array_spec=spec)
70+
assert len(chunk.layers) == 1
71+
assert chunk.layers[0][0] is transpose
72+
assert chunk.layers[0][1] is spec
73+
74+
def test_shape_dtype_no_aa_codecs(self) -> None:
75+
spec = _make_array_spec((100,), np.dtype("float64"))
76+
chunk = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec)
77+
assert chunk.shape == (100,)
78+
assert chunk.dtype == spec.dtype
79+
80+
def test_shape_dtype_with_transpose(self) -> None:
81+
spec = _make_array_spec((3, 4), np.dtype("float64"))
82+
chunk = ChunkTransform(codecs=(TransposeCodec(order=(1, 0)), BytesCodec()), array_spec=spec)
83+
# After transpose (1,0), shape (3,4) becomes (4,3)
84+
assert chunk.shape == (4, 3)
85+
assert chunk.dtype == spec.dtype
86+
5987
def test_encode_decode_roundtrip_with_compression(self) -> None:
6088
arr = np.arange(100, dtype="float64")
6189
spec = _make_array_spec(arr.shape, arr.dtype)
62-
chain = CodecChain((BytesCodec(), GzipCodec(level=1)), spec)
90+
chain = ChunkTransform(codecs=(BytesCodec(), GzipCodec(level=1)), array_spec=spec)
6391
nd_buf = _make_nd_buffer(arr)
6492

6593
encoded = chain.encode_chunk(nd_buf)
@@ -70,7 +98,10 @@ def test_encode_decode_roundtrip_with_compression(self) -> None:
7098
def test_encode_decode_roundtrip_with_transpose(self) -> None:
7199
arr = np.arange(12, dtype="float64").reshape(3, 4)
72100
spec = _make_array_spec(arr.shape, arr.dtype)
73-
chain = CodecChain((TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)), spec)
101+
chain = ChunkTransform(
102+
codecs=(TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)),
103+
array_spec=spec,
104+
)
74105
nd_buf = _make_nd_buffer(arr)
75106

76107
encoded = chain.encode_chunk(nd_buf)

0 commit comments

Comments
 (0)