Skip to content

Commit d22b6f0

Browse files
committed
chunktransform requires sync codecs
1 parent 4b22c46 commit d22b6f0

2 files changed

Lines changed: 158 additions & 26 deletions

File tree

src/zarr/core/codec_pipeline.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from dataclasses import dataclass, field
44
from itertools import islice, pairwise
5-
from typing import TYPE_CHECKING, Any, TypeVar, cast
5+
from typing import TYPE_CHECKING, Any, TypeVar
66
from warnings import warn
77

88
from zarr.abc.codec import (
@@ -80,6 +80,9 @@ class ChunkTransform:
8080
The chunk's ``shape`` and ``dtype`` reflect the representation
8181
**after** all ArrayArrayCodec layers have been applied — i.e. the
8282
spec that feeds the ArrayBytesCodec.
83+
84+
All codecs must implement ``SupportsSyncCodec``. Construction will
85+
raise ``TypeError`` if any codec does not.
8386
"""
8487

8588
codecs: tuple[Codec, ...]
@@ -92,9 +95,15 @@ class ChunkTransform:
9295
_ab_codec: ArrayBytesCodec = field(init=False, repr=False, compare=False)
9396
_ab_spec: ArraySpec = field(init=False, repr=False, compare=False)
9497
_bb_codecs: tuple[BytesBytesCodec, ...] = field(init=False, repr=False, compare=False)
95-
_all_sync: bool = field(init=False, repr=False, compare=False)
9698

9799
def __post_init__(self) -> None:
100+
non_sync = [c for c in self.codecs if not isinstance(c, SupportsSyncCodec)]
101+
if non_sync:
102+
names = ", ".join(type(c).__name__ for c in non_sync)
103+
raise TypeError(
104+
f"All codecs must implement SupportsSyncCodec. The following do not: {names}"
105+
)
106+
98107
aa, ab, bb = codecs_from_list(list(self.codecs))
99108

100109
layers: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = ()
@@ -107,7 +116,6 @@ def __post_init__(self) -> None:
107116
self._ab_codec = ab
108117
self._ab_spec = spec
109118
self._bb_codecs = bb
110-
self._all_sync = all(isinstance(c, SupportsSyncCodec) for c in self.codecs)
111119

112120
@property
113121
def shape(self) -> tuple[int, ...]:
@@ -119,26 +127,22 @@ def dtype(self) -> ZDType[TBaseDType, TBaseScalar]:
119127
"""Dtype after all ArrayArrayCodec layers (input to the ArrayBytesCodec)."""
120128
return self._ab_spec.dtype
121129

122-
@property
123-
def all_sync(self) -> bool:
124-
return self._all_sync
125-
126130
def decode_chunk(
127131
self,
128132
chunk_bytes: Buffer,
129133
) -> NDBuffer:
130134
"""Decode a single chunk through the full codec chain, synchronously.
131135
132-
Pure compute -- no IO. Only callable when all codecs support sync.
136+
Pure compute -- no IO.
133137
"""
134138
bb_out: Any = chunk_bytes
135139
for bb_codec in reversed(self._bb_codecs):
136-
bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self._ab_spec)
140+
bb_out = bb_codec._decode_sync(bb_out, self._ab_spec) # type: ignore[union-attr]
137141

138-
ab_out: Any = cast("SupportsSyncCodec", self._ab_codec)._decode_sync(bb_out, self._ab_spec)
142+
ab_out: Any = self._ab_codec._decode_sync(bb_out, self._ab_spec) # type: ignore[union-attr]
139143

140144
for aa_codec, spec in reversed(self.layers):
141-
ab_out = cast("SupportsSyncCodec", aa_codec)._decode_sync(ab_out, spec)
145+
ab_out = aa_codec._decode_sync(ab_out, spec) # type: ignore[union-attr]
142146

143147
return ab_out # type: ignore[no-any-return]
144148

@@ -148,23 +152,23 @@ def encode_chunk(
148152
) -> Buffer | None:
149153
"""Encode a single chunk through the full codec chain, synchronously.
150154
151-
Pure compute -- no IO. Only callable when all codecs support sync.
155+
Pure compute -- no IO.
152156
"""
153157
aa_out: Any = chunk_array
154158

155159
for aa_codec, spec in self.layers:
156160
if aa_out is None:
157161
return None
158-
aa_out = cast("SupportsSyncCodec", aa_codec)._encode_sync(aa_out, spec)
162+
aa_out = aa_codec._encode_sync(aa_out, spec) # type: ignore[union-attr]
159163

160164
if aa_out is None:
161165
return None
162-
bb_out: Any = cast("SupportsSyncCodec", self._ab_codec)._encode_sync(aa_out, self._ab_spec)
166+
bb_out: Any = self._ab_codec._encode_sync(aa_out, self._ab_spec) # type: ignore[union-attr]
163167

164168
for bb_codec in self._bb_codecs:
165169
if bb_out is None:
166170
return None
167-
bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self._ab_spec)
171+
bb_out = bb_codec._encode_sync(bb_out, self._ab_spec) # type: ignore[union-attr]
168172

169173
return bb_out # type: ignore[no-any-return]
170174

tests/test_sync_codec_pipeline.py

Lines changed: 139 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
from typing import Any
44

55
import numpy as np
6+
import pytest
67

8+
from zarr.abc.codec import ArrayBytesCodec
79
from zarr.codecs.bytes import BytesCodec
10+
from zarr.codecs.crc32c_ import Crc32cCodec
811
from zarr.codecs.gzip import GzipCodec
912
from zarr.codecs.transpose import TransposeCodec
1013
from zarr.codecs.zstd import ZstdCodec
1114
from zarr.core.array_spec import ArrayConfig, ArraySpec
12-
from zarr.core.buffer import NDBuffer, default_buffer_prototype
15+
from zarr.core.buffer import Buffer, NDBuffer, default_buffer_prototype
1316
from zarr.core.codec_pipeline import ChunkTransform
1417
from zarr.core.dtype import get_data_type_from_native_dtype
1518

@@ -30,24 +33,26 @@ def _make_nd_buffer(arr: np.ndarray[Any, np.dtype[Any]]) -> NDBuffer:
3033

3134

3235
class TestChunkTransform:
33-
def test_all_sync(self) -> None:
36+
def test_construction_bytes_only(self) -> None:
37+
# Construction succeeds when all codecs implement SupportsSyncCodec.
3438
spec = _make_array_spec((100,), np.dtype("float64"))
35-
chain = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec)
36-
assert chain.all_sync is True
39+
ChunkTransform(codecs=(BytesCodec(),), array_spec=spec)
3740

38-
def test_all_sync_with_compression(self) -> None:
41+
def test_construction_with_compression(self) -> None:
42+
# AB + BB codec chain where both implement SupportsSyncCodec.
3943
spec = _make_array_spec((100,), np.dtype("float64"))
40-
chain = ChunkTransform(codecs=(BytesCodec(), GzipCodec()), array_spec=spec)
41-
assert chain.all_sync is True
44+
ChunkTransform(codecs=(BytesCodec(), GzipCodec()), array_spec=spec)
4245

43-
def test_all_sync_full_chain(self) -> None:
46+
def test_construction_full_chain(self) -> None:
47+
# All three codec types (AA + AB + BB), all implementing SupportsSyncCodec.
4448
spec = _make_array_spec((3, 4), np.dtype("float64"))
45-
chain = ChunkTransform(
49+
ChunkTransform(
4650
codecs=(TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()), array_spec=spec
4751
)
48-
assert chain.all_sync is True
4952

5053
def test_encode_decode_roundtrip_bytes_only(self) -> None:
54+
# Minimal round-trip: BytesCodec serializes the array to bytes and back.
55+
# No compression, no AA transform.
5156
arr = np.arange(100, dtype="float64")
5257
spec = _make_array_spec(arr.shape, arr.dtype)
5358
chain = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec)
@@ -59,11 +64,14 @@ def test_encode_decode_roundtrip_bytes_only(self) -> None:
5964
np.testing.assert_array_equal(arr, decoded.as_numpy_array())
6065

6166
def test_layers_no_aa_codecs(self) -> None:
67+
# When there are no ArrayArrayCodecs, layers should be empty.
6268
spec = _make_array_spec((100,), np.dtype("float64"))
6369
chunk = ChunkTransform(codecs=(BytesCodec(), GzipCodec()), array_spec=spec)
6470
assert chunk.layers == ()
6571

6672
def test_layers_with_transpose(self) -> None:
73+
# With one AA codec (TransposeCodec), layers should contain exactly one
74+
# entry pairing the codec with its input ArraySpec.
6775
spec = _make_array_spec((3, 4), np.dtype("float64"))
6876
transpose = TransposeCodec(order=(1, 0))
6977
chunk = ChunkTransform(codecs=(transpose, BytesCodec(), ZstdCodec()), array_spec=spec)
@@ -72,19 +80,24 @@ def test_layers_with_transpose(self) -> None:
7280
assert chunk.layers[0][1] is spec
7381

7482
def test_shape_dtype_no_aa_codecs(self) -> None:
83+
# Without AA codecs, shape and dtype should match the input ArraySpec
84+
# (no transforms applied before the AB codec).
7585
spec = _make_array_spec((100,), np.dtype("float64"))
7686
chunk = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec)
7787
assert chunk.shape == (100,)
7888
assert chunk.dtype == spec.dtype
7989

8090
def test_shape_dtype_with_transpose(self) -> None:
91+
# TransposeCodec(order=(1,0)) on a (3, 4) array produces (4, 3).
92+
# shape/dtype reflect what the AB codec sees after all AA transforms.
8193
spec = _make_array_spec((3, 4), np.dtype("float64"))
8294
chunk = ChunkTransform(codecs=(TransposeCodec(order=(1, 0)), BytesCodec()), array_spec=spec)
83-
# After transpose (1,0), shape (3,4) becomes (4,3)
8495
assert chunk.shape == (4, 3)
8596
assert chunk.dtype == spec.dtype
8697

8798
def test_encode_decode_roundtrip_with_compression(self) -> None:
99+
# Round-trip with a BB codec (GzipCodec) to verify that bytes-bytes
100+
# compression/decompression is wired correctly.
88101
arr = np.arange(100, dtype="float64")
89102
spec = _make_array_spec(arr.shape, arr.dtype)
90103
chain = ChunkTransform(codecs=(BytesCodec(), GzipCodec(level=1)), array_spec=spec)
@@ -96,6 +109,9 @@ def test_encode_decode_roundtrip_with_compression(self) -> None:
96109
np.testing.assert_array_equal(arr, decoded.as_numpy_array())
97110

98111
def test_encode_decode_roundtrip_with_transpose(self) -> None:
112+
# Full AA + AB + BB chain round-trip. Transpose permutes axes on encode,
113+
# then BytesCodec serializes, then ZstdCodec compresses. Decode reverses
114+
# all three stages. Verifies the full pipeline works end to end.
99115
arr = np.arange(12, dtype="float64").reshape(3, 4)
100116
spec = _make_array_spec(arr.shape, arr.dtype)
101117
chain = ChunkTransform(
@@ -108,3 +124,115 @@ def test_encode_decode_roundtrip_with_transpose(self) -> None:
108124
assert encoded is not None
109125
decoded = chain.decode_chunk(encoded)
110126
np.testing.assert_array_equal(arr, decoded.as_numpy_array())
127+
128+
def test_rejects_non_sync_codec(self) -> None:
129+
# Construction must raise TypeError when a codec lacks SupportsSyncCodec.
130+
131+
class AsyncOnlyCodec(ArrayBytesCodec):
132+
is_fixed_size = True
133+
134+
async def _decode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> NDBuffer:
135+
raise NotImplementedError # pragma: no cover
136+
137+
async def _encode_single(
138+
self, chunk_array: NDBuffer, chunk_spec: ArraySpec
139+
) -> Buffer | None:
140+
raise NotImplementedError # pragma: no cover
141+
142+
def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
143+
return input_byte_length # pragma: no cover
144+
145+
spec = _make_array_spec((100,), np.dtype("float64"))
146+
with pytest.raises(TypeError, match="AsyncOnlyCodec"):
147+
ChunkTransform(codecs=(AsyncOnlyCodec(),), array_spec=spec)
148+
149+
def test_rejects_mixed_sync_and_non_sync(self) -> None:
150+
# Even if some codecs support sync, a single non-sync codec should
151+
# cause construction to fail.
152+
153+
class AsyncOnlyCodec(ArrayBytesCodec):
154+
is_fixed_size = True
155+
156+
async def _decode_single(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> NDBuffer:
157+
raise NotImplementedError # pragma: no cover
158+
159+
async def _encode_single(
160+
self, chunk_array: NDBuffer, chunk_spec: ArraySpec
161+
) -> Buffer | None:
162+
raise NotImplementedError # pragma: no cover
163+
164+
def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
165+
return input_byte_length # pragma: no cover
166+
167+
spec = _make_array_spec((3, 4), np.dtype("float64"))
168+
with pytest.raises(TypeError, match="AsyncOnlyCodec"):
169+
ChunkTransform(
170+
codecs=(TransposeCodec(order=(1, 0)), AsyncOnlyCodec()),
171+
array_spec=spec,
172+
)
173+
174+
def test_compute_encoded_size_bytes_only(self) -> None:
175+
# BytesCodec is size-preserving: encoded size == input size.
176+
spec = _make_array_spec((100,), np.dtype("float64"))
177+
chain = ChunkTransform(codecs=(BytesCodec(),), array_spec=spec)
178+
assert chain.compute_encoded_size(800, spec) == 800
179+
180+
def test_compute_encoded_size_with_crc32c(self) -> None:
181+
# Crc32cCodec appends a 4-byte checksum, so encoded size = input + 4.
182+
spec = _make_array_spec((100,), np.dtype("float64"))
183+
chain = ChunkTransform(codecs=(BytesCodec(), Crc32cCodec()), array_spec=spec)
184+
assert chain.compute_encoded_size(800, spec) == 804
185+
186+
def test_compute_encoded_size_with_transpose(self) -> None:
187+
# TransposeCodec reorders axes but doesn't change the byte count.
188+
# Verifies that compute_encoded_size walks through AA codecs correctly.
189+
spec = _make_array_spec((3, 4), np.dtype("float64"))
190+
chain = ChunkTransform(codecs=(TransposeCodec(order=(1, 0)), BytesCodec()), array_spec=spec)
191+
assert chain.compute_encoded_size(96, spec) == 96
192+
193+
def test_encode_chunk_returns_none_propagation(self) -> None:
194+
# When an AA codec returns None (signaling "this chunk is the fill value,
195+
# don't store it"), encode_chunk must short-circuit and return None
196+
# instead of passing None into the next codec.
197+
198+
class NoneReturningAACodec(TransposeCodec):
199+
"""An ArrayArrayCodec that always returns None from encode."""
200+
201+
def _encode_sync(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer | None:
202+
return None
203+
204+
spec = _make_array_spec((3, 4), np.dtype("float64"))
205+
chain = ChunkTransform(
206+
codecs=(NoneReturningAACodec(order=(1, 0)), BytesCodec()),
207+
array_spec=spec,
208+
)
209+
arr = np.arange(12, dtype="float64").reshape(3, 4)
210+
nd_buf = _make_nd_buffer(arr)
211+
assert chain.encode_chunk(nd_buf) is None
212+
213+
def test_encode_decode_roundtrip_with_crc32c(self) -> None:
214+
# Round-trip through BytesCodec + Crc32cCodec. Crc32c appends a checksum
215+
# on encode and verifies it on decode, so this tests that the BB codec
216+
# pipeline runs correctly in both directions.
217+
arr = np.arange(100, dtype="float64")
218+
spec = _make_array_spec(arr.shape, arr.dtype)
219+
chain = ChunkTransform(codecs=(BytesCodec(), Crc32cCodec()), array_spec=spec)
220+
nd_buf = _make_nd_buffer(arr)
221+
222+
encoded = chain.encode_chunk(nd_buf)
223+
assert encoded is not None
224+
decoded = chain.decode_chunk(encoded)
225+
np.testing.assert_array_equal(arr, decoded.as_numpy_array())
226+
227+
def test_encode_decode_roundtrip_int32(self) -> None:
228+
# Round-trip with int32 data to verify that the codec chain is not
229+
# float-specific. Exercises a different dtype path through BytesCodec.
230+
arr = np.arange(50, dtype="int32")
231+
spec = _make_array_spec(arr.shape, arr.dtype)
232+
chain = ChunkTransform(codecs=(BytesCodec(), ZstdCodec(level=1)), array_spec=spec)
233+
nd_buf = _make_nd_buffer(arr)
234+
235+
encoded = chain.encode_chunk(nd_buf)
236+
assert encoded is not None
237+
decoded = chain.decode_chunk(encoded)
238+
np.testing.assert_array_equal(arr, decoded.as_numpy_array())

0 commit comments

Comments
 (0)