Skip to content

Commit c9b534a

Browse files
d-v-b and claude authored
perf/chunktransform (zarr-developers#3722)
* add sync methods to codecs * add CodecChain dataclass and sync codec tests Introduces CodecChain, a frozen dataclass that chains array-array, array-bytes, and bytes-bytes codecs with synchronous encode/decode methods. Pure compute only -- no IO, no threading, no batching. Also adds sync roundtrip tests for individual codecs (blosc, gzip, zstd, crc32c, bytes, transpose, vlen) and CodecChain integration tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor codecchain * separate codecs and specs * add synchronous methods to stores * chunktransform * remove memorystore changes * chunktransform requires sync codecs * lint * simplify chunktransform by removing layers * rename to encode / decode * docs: improve type: ignore explanations * refactor: SupportsSyncCodec is generic, like BaseCodec * chore: remove shape and dtype attributes * test: update tests * test: clean up tests * chore: remove type: ignores --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0081fef commit c9b534a

File tree

3 files changed

+259
-9
lines changed

3 files changed

+259
-9
lines changed

src/zarr/abc/codec.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,19 @@ def _check_codecjson_v2(data: object) -> TypeGuard[CodecJSON_V2[str]]:
6767

6868

6969
@runtime_checkable
70-
class SupportsSyncCodec(Protocol):
70+
class SupportsSyncCodec[CI: CodecInput, CO: CodecOutput](Protocol):
7171
"""Protocol for codecs that support synchronous encode/decode.
7272
73-
Codecs implementing this protocol provide ``_decode_sync`` and ``_encode_sync``
73+
Codecs implementing this protocol provide `_decode_sync` and `_encode_sync`
7474
methods that perform encoding/decoding without requiring an async event loop.
75+
76+
The type parameters mirror `BaseCodec`: `CI` is the decoded type and `CO` is
77+
the encoded type.
7578
"""
7679

77-
def _decode_sync(
78-
self, chunk_data: NDBuffer | Buffer, chunk_spec: ArraySpec
79-
) -> NDBuffer | Buffer: ...
80+
def _decode_sync(self, chunk_data: CO, chunk_spec: ArraySpec) -> CI: ...
8081

81-
def _encode_sync(
82-
self, chunk_data: NDBuffer | Buffer, chunk_spec: ArraySpec
83-
) -> NDBuffer | Buffer | None: ...
82+
def _encode_sync(self, chunk_data: CI, chunk_spec: ArraySpec) -> CO | None: ...
8483

8584

8685
class BaseCodec[CI: CodecInput, CO: CodecOutput](Metadata):

src/zarr/core/codec_pipeline.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from dataclasses import dataclass
3+
from dataclasses import dataclass, field
44
from itertools import islice, pairwise
55
from typing import TYPE_CHECKING, Any
66
from warnings import warn
@@ -14,6 +14,7 @@
1414
Codec,
1515
CodecPipeline,
1616
GetResult,
17+
SupportsSyncCodec,
1718
)
1819
from zarr.core.common import concurrent_map
1920
from zarr.core.config import config
@@ -66,6 +67,111 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any:
6667
return fill_value
6768

6869

70+
@dataclass(slots=True, kw_only=True)
class ChunkTransform:
    """A synchronous codec chain pre-resolved against one ArraySpec.

    `encode` and `decode` are pure compute -- no IO, no threading, no
    batching. Every codec must implement `SupportsSyncCodec`; construction
    raises `TypeError` otherwise.
    """

    codecs: tuple[Codec, ...]
    array_spec: ArraySpec

    # Derived state, populated in __post_init__:
    # array->array codecs paired with the spec each receives as input.
    _aa_codecs: tuple[tuple[SupportsSyncCodec[NDBuffer, NDBuffer], ArraySpec], ...] = field(
        init=False, repr=False, compare=False
    )
    # The single array->bytes codec and the spec it operates on.
    _ab_codec: SupportsSyncCodec[NDBuffer, Buffer] = field(init=False, repr=False, compare=False)
    _ab_spec: ArraySpec = field(init=False, repr=False, compare=False)
    # bytes->bytes codecs in pipeline order.
    _bb_codecs: tuple[SupportsSyncCodec[Buffer, Buffer], ...] = field(
        init=False, repr=False, compare=False
    )

    def __post_init__(self) -> None:
        # Reject the whole chain if any codec lacks the sync protocol.
        unsupported = [c for c in self.codecs if not isinstance(c, SupportsSyncCodec)]
        if unsupported:
            names = ", ".join(type(c).__name__ for c in unsupported)
            raise TypeError(
                f"All codecs must implement SupportsSyncCodec. The following do not: {names}"
            )

        aa_part, ab_part, bb_part = codecs_from_list(list(self.codecs))

        # Thread the spec through the array->array codecs, remembering the
        # input spec each codec will see at encode/decode time.
        resolved: list[tuple[SupportsSyncCodec[NDBuffer, NDBuffer], ArraySpec]] = []
        current_spec = self.array_spec
        for codec in aa_part:
            assert isinstance(codec, SupportsSyncCodec)
            resolved.append((codec, current_spec))
            current_spec = codec.resolve_metadata(current_spec)
        self._aa_codecs = tuple(resolved)

        assert isinstance(ab_part, SupportsSyncCodec)
        self._ab_codec = ab_part
        self._ab_spec = current_spec

        for codec in bb_part:
            assert isinstance(codec, SupportsSyncCodec)
        self._bb_codecs = tuple(bb_part)

    def decode(
        self,
        chunk_bytes: Buffer,
    ) -> NDBuffer:
        """Synchronously decode one chunk through the whole chain.

        Pure compute -- no IO.
        """
        buf: Buffer = chunk_bytes
        # Undo bytes->bytes codecs, last-applied first.
        for codec in reversed(self._bb_codecs):
            buf = codec._decode_sync(buf, self._ab_spec)

        nd: NDBuffer = self._ab_codec._decode_sync(buf, self._ab_spec)

        # Undo array->array codecs, each with the spec it originally saw.
        for codec, spec in reversed(self._aa_codecs):
            nd = codec._decode_sync(nd, spec)

        return nd

    def encode(
        self,
        chunk_array: NDBuffer,
    ) -> Buffer | None:
        """Synchronously encode one chunk through the whole chain.

        Short-circuits to None as soon as any codec's encode yields None.
        Pure compute -- no IO.
        """
        nd: NDBuffer = chunk_array
        for codec, spec in self._aa_codecs:
            aa_step = codec._encode_sync(nd, spec)
            if aa_step is None:
                return None
            nd = aa_step

        encoded = self._ab_codec._encode_sync(nd, self._ab_spec)
        if encoded is None:
            return None

        buf: Buffer = encoded
        for codec in self._bb_codecs:
            bb_step = codec._encode_sync(buf, self._ab_spec)
            if bb_step is None:
                return None
            buf = bb_step

        return buf

    def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
        """Fold each codec's size estimate through the chain, resolving the spec as we go."""
        size = byte_length
        spec = array_spec
        for codec in self.codecs:
            size = codec.compute_encoded_size(size, spec)
            spec = codec.resolve_metadata(spec)
        return size
173+
174+
69175
@dataclass(frozen=True)
70176
class BatchedCodecPipeline(CodecPipeline):
71177
"""Default codec pipeline.

tests/test_sync_codec_pipeline.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
import numpy as np
6+
import pytest
7+
8+
from zarr.abc.codec import ArrayBytesCodec, Codec
9+
from zarr.codecs.bytes import BytesCodec
10+
from zarr.codecs.crc32c_ import Crc32cCodec
11+
from zarr.codecs.gzip import GzipCodec
12+
from zarr.codecs.transpose import TransposeCodec
13+
from zarr.codecs.zstd import ZstdCodec
14+
from zarr.core.array_spec import ArrayConfig, ArraySpec
15+
from zarr.core.buffer import Buffer, NDBuffer, default_buffer_prototype
16+
from zarr.core.codec_pipeline import ChunkTransform
17+
from zarr.core.dtype import get_data_type_from_native_dtype
18+
19+
20+
class AsyncOnlyCodec(ArrayBytesCodec):
    """Codec exposing only the async API, used to verify that non-sync codecs are rejected."""

    is_fixed_size = True

    async def _decode_single(self, chunk_data: Buffer, chunk_spec: ArraySpec) -> NDBuffer:
        raise NotImplementedError  # pragma: no cover

    async def _encode_single(self, chunk_data: NDBuffer, chunk_spec: ArraySpec) -> Buffer | None:
        raise NotImplementedError  # pragma: no cover

    def compute_encoded_size(self, input_byte_length: int, chunk_spec: ArraySpec) -> int:
        return input_byte_length  # pragma: no cover
33+
34+
35+
def _make_array_spec(shape: tuple[int, ...], dtype: np.dtype[np.generic]) -> ArraySpec:
    """Build an ArraySpec with C order, a zero fill value, and the default buffer prototype."""
    zarr_dtype = get_data_type_from_native_dtype(dtype)
    return ArraySpec(
        shape=shape,
        dtype=zarr_dtype,
        fill_value=zarr_dtype.cast_scalar(0),
        config=ArrayConfig(order="C", write_empty_chunks=True),
        prototype=default_buffer_prototype(),
    )
44+
45+
46+
def _make_nd_buffer(arr: np.ndarray[Any, np.dtype[Any]]) -> NDBuffer:
    """Wrap a numpy array in the default prototype's NDBuffer."""
    prototype = default_buffer_prototype()
    return prototype.nd_buffer.from_numpy_array(arr)
48+
49+
50+
@pytest.mark.parametrize(
    ("shape", "codecs"),
    [
        ((100,), (BytesCodec(),)),
        ((100,), (BytesCodec(), GzipCodec())),
        ((3, 4), (TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec())),
    ],
    ids=["bytes-only", "with-compression", "full-chain"],
)
def test_construction(shape: tuple[int, ...], codecs: tuple[Codec, ...]) -> None:
    """All-sync codec chains construct without error."""
    array_spec = _make_array_spec(shape, np.dtype("float64"))
    ChunkTransform(codecs=codecs, array_spec=array_spec)
63+
64+
65+
@pytest.mark.parametrize(
    ("shape", "codecs"),
    [
        ((100,), (AsyncOnlyCodec(),)),
        ((3, 4), (TransposeCodec(order=(1, 0)), AsyncOnlyCodec())),
    ],
    ids=["async-only", "mixed-sync-and-async"],
)
def test_construction_rejects_non_sync(shape: tuple[int, ...], codecs: tuple[Codec, ...]) -> None:
    """A codec lacking SupportsSyncCodec makes construction fail with TypeError."""
    array_spec = _make_array_spec(shape, np.dtype("float64"))
    with pytest.raises(TypeError, match="AsyncOnlyCodec"):
        ChunkTransform(codecs=codecs, array_spec=array_spec)
78+
79+
80+
@pytest.mark.parametrize(
    ("arr", "codecs"),
    [
        (np.arange(100, dtype="float64"), (BytesCodec(),)),
        (np.arange(100, dtype="float64"), (BytesCodec(), GzipCodec(level=1))),
        (
            np.arange(12, dtype="float64").reshape(3, 4),
            (TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)),
        ),
        (np.arange(100, dtype="float64"), (BytesCodec(), Crc32cCodec())),
        (np.arange(50, dtype="int32"), (BytesCodec(), ZstdCodec(level=1))),
    ],
    ids=["bytes-only", "gzip", "transpose+zstd", "crc32c", "int32"],
)
def test_encode_decode_roundtrip(
    arr: np.ndarray[Any, np.dtype[Any]], codecs: tuple[Codec, ...]
) -> None:
    """Encoding then decoding reproduces the input data exactly."""
    transform = ChunkTransform(codecs=codecs, array_spec=_make_array_spec(arr.shape, arr.dtype))

    encoded = transform.encode(_make_nd_buffer(arr))
    assert encoded is not None
    roundtripped = transform.decode(encoded)
    np.testing.assert_array_equal(arr, roundtripped.as_numpy_array())
106+
107+
108+
@pytest.mark.parametrize(
    ("shape", "codecs", "input_size", "expected_size"),
    [
        ((100,), (BytesCodec(),), 800, 800),
        ((100,), (BytesCodec(), Crc32cCodec()), 800, 804),
        ((3, 4), (TransposeCodec(order=(1, 0)), BytesCodec()), 96, 96),
    ],
    ids=["bytes-only", "crc32c", "transpose"],
)
def test_compute_encoded_size(
    shape: tuple[int, ...],
    codecs: tuple[Codec, ...],
    input_size: int,
    expected_size: int,
) -> None:
    """compute_encoded_size reports the expected byte length for the chain."""
    array_spec = _make_array_spec(shape, np.dtype("float64"))
    transform = ChunkTransform(codecs=codecs, array_spec=array_spec)
    assert transform.compute_encoded_size(input_size, array_spec) == expected_size
127+
128+
129+
def test_encode_returns_none_propagation() -> None:
    """An AA codec returning None makes encode short-circuit to None."""

    class NoneReturningAACodec(TransposeCodec):
        """Transpose variant whose sync encode always yields None."""

        def _encode_sync(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> NDBuffer | None:
            return None

    transform = ChunkTransform(
        codecs=(NoneReturningAACodec(order=(1, 0)), BytesCodec()),
        array_spec=_make_array_spec((3, 4), np.dtype("float64")),
    )
    data = np.arange(12, dtype="float64").reshape(3, 4)
    assert transform.encode(_make_nd_buffer(data)) is None

0 commit comments

Comments
 (0)