Skip to content

Commit 41b7a6a

Browse files
committed
refactor codecchain
1 parent cd4efb0 commit 41b7a6a

File tree

2 files changed

+65
-172
lines changed

2 files changed

+65
-172
lines changed

src/zarr/core/codec_pipeline.py

Lines changed: 42 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -69,133 +69,103 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any:
6969
return fill_value
7070

7171

72-
@dataclass(frozen=True)
72+
@dataclass(frozen=True, slots=True)
7373
class CodecChain:
74-
"""Lightweight codec chain: array-array -> array-bytes -> bytes-bytes.
74+
"""Codec chain with pre-resolved metadata specs.
7575
76-
Pure compute only -- no IO methods, no threading, no batching.
76+
Constructed from an iterable of codecs and a chunk ArraySpec.
77+
Resolves each codec against the spec so that encode/decode can
78+
run without re-resolving. Pure compute only -- no IO, no threading,
79+
no batching.
7780
"""
7881

79-
array_array_codecs: tuple[ArrayArrayCodec, ...]
80-
array_bytes_codec: ArrayBytesCodec
81-
bytes_bytes_codecs: tuple[BytesBytesCodec, ...]
82+
codecs: tuple[Codec, ...]
83+
chunk_spec: ArraySpec
8284

83-
_all_sync: bool = field(default=False, init=False, repr=False, compare=False)
85+
_aa_codecs: tuple[tuple[ArrayArrayCodec, ArraySpec], ...] = field(
86+
init=False, repr=False, compare=False
87+
)
88+
_ab_codec: ArrayBytesCodec = field(init=False, repr=False, compare=False)
89+
_ab_spec: ArraySpec = field(init=False, repr=False, compare=False)
90+
_bb_codecs: tuple[BytesBytesCodec, ...] = field(init=False, repr=False, compare=False)
91+
_all_sync: bool = field(init=False, repr=False, compare=False)
8492

8593
def __post_init__(self) -> None:
86-
object.__setattr__(
87-
self,
88-
"_all_sync",
89-
all(isinstance(c, SupportsSyncCodec) for c in self),
90-
)
91-
92-
def __iter__(self) -> Iterator[Codec]:
93-
yield from self.array_array_codecs
94-
yield self.array_bytes_codec
95-
yield from self.bytes_bytes_codecs
94+
aa, ab, bb = codecs_from_list(list(self.codecs))
9695

97-
@classmethod
98-
def from_codecs(cls, codecs: Iterable[Codec]) -> CodecChain:
99-
aa, ab, bb = codecs_from_list(list(codecs))
100-
return cls(array_array_codecs=aa, array_bytes_codec=ab, bytes_bytes_codecs=bb)
101-
102-
def resolve_metadata_chain(
103-
self, chunk_spec: ArraySpec
104-
) -> tuple[
105-
list[tuple[ArrayArrayCodec, ArraySpec]],
106-
tuple[ArrayBytesCodec, ArraySpec],
107-
list[tuple[BytesBytesCodec, ArraySpec]],
108-
]:
109-
"""Resolve metadata through the codec chain for a single chunk_spec."""
110-
aa_codecs_with_spec: list[tuple[ArrayArrayCodec, ArraySpec]] = []
111-
spec = chunk_spec
112-
for aa_codec in self.array_array_codecs:
113-
aa_codecs_with_spec.append((aa_codec, spec))
96+
aa_pairs: list[tuple[ArrayArrayCodec, ArraySpec]] = []
97+
spec = self.chunk_spec
98+
for aa_codec in aa:
99+
aa_pairs.append((aa_codec, spec))
114100
spec = aa_codec.resolve_metadata(spec)
115101

116-
ab_codec_with_spec = (self.array_bytes_codec, spec)
117-
spec = self.array_bytes_codec.resolve_metadata(spec)
102+
object.__setattr__(self, "_aa_codecs", tuple(aa_pairs))
103+
object.__setattr__(self, "_ab_codec", ab)
104+
object.__setattr__(self, "_ab_spec", spec)
118105

119-
bb_codecs_with_spec: list[tuple[BytesBytesCodec, ArraySpec]] = []
120-
for bb_codec in self.bytes_bytes_codecs:
121-
bb_codecs_with_spec.append((bb_codec, spec))
122-
spec = bb_codec.resolve_metadata(spec)
106+
object.__setattr__(self, "_bb_codecs", bb)
123107

124-
return (aa_codecs_with_spec, ab_codec_with_spec, bb_codecs_with_spec)
108+
object.__setattr__(
109+
self,
110+
"_all_sync",
111+
all(isinstance(c, SupportsSyncCodec) for c in self.codecs),
112+
)
113+
114+
@property
115+
def all_sync(self) -> bool:
116+
return self._all_sync
125117

126118
def decode_chunk(
127119
self,
128120
chunk_bytes: Buffer,
129-
chunk_spec: ArraySpec,
130-
aa_chain: Iterable[tuple[ArrayArrayCodec, ArraySpec]] | None = None,
131-
ab_pair: tuple[ArrayBytesCodec, ArraySpec] | None = None,
132-
bb_chain: Iterable[tuple[BytesBytesCodec, ArraySpec]] | None = None,
133121
) -> NDBuffer:
134122
"""Decode a single chunk through the full codec chain, synchronously.
135123
136124
Pure compute -- no IO. Only callable when all codecs support sync.
137-
138-
The optional ``aa_chain``, ``ab_pair``, ``bb_chain`` parameters allow
139-
pre-resolved metadata to be reused across many chunks with the same spec.
140-
If not provided, ``resolve_metadata_chain`` is called internally.
141125
"""
142-
if aa_chain is None or ab_pair is None or bb_chain is None:
143-
aa_chain, ab_pair, bb_chain = self.resolve_metadata_chain(chunk_spec)
144-
145126
bb_out: Any = chunk_bytes
146-
for bb_codec, spec in reversed(list(bb_chain)):
147-
bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, spec)
127+
for bb_codec in reversed(self._bb_codecs):
128+
bb_out = cast("SupportsSyncCodec", bb_codec)._decode_sync(bb_out, self.chunk_spec)
148129

149-
ab_codec, ab_spec = ab_pair
150-
ab_out: Any = cast("SupportsSyncCodec", ab_codec)._decode_sync(bb_out, ab_spec)
130+
ab_out: Any = cast("SupportsSyncCodec", self._ab_codec)._decode_sync(bb_out, self._ab_spec)
151131

152-
for aa_codec, spec in reversed(list(aa_chain)):
132+
for aa_codec, spec in reversed(self._aa_codecs):
153133
ab_out = cast("SupportsSyncCodec", aa_codec)._decode_sync(ab_out, spec)
154134

155135
return ab_out # type: ignore[no-any-return]
156136

157137
def encode_chunk(
158138
self,
159139
chunk_array: NDBuffer,
160-
chunk_spec: ArraySpec,
161140
) -> Buffer | None:
162141
"""Encode a single chunk through the full codec chain, synchronously.
163142
164143
Pure compute -- no IO. Only callable when all codecs support sync.
165144
"""
166-
spec = chunk_spec
167145
aa_out: Any = chunk_array
168146

169-
for aa_codec in self.array_array_codecs:
147+
for aa_codec, spec in self._aa_codecs:
170148
if aa_out is None:
171149
return None
172150
aa_out = cast("SupportsSyncCodec", aa_codec)._encode_sync(aa_out, spec)
173-
spec = aa_codec.resolve_metadata(spec)
174151

175152
if aa_out is None:
176153
return None
177-
bb_out: Any = cast("SupportsSyncCodec", self.array_bytes_codec)._encode_sync(aa_out, spec)
178-
spec = self.array_bytes_codec.resolve_metadata(spec)
154+
bb_out: Any = cast("SupportsSyncCodec", self._ab_codec)._encode_sync(aa_out, self._ab_spec)
179155

180-
for bb_codec in self.bytes_bytes_codecs:
156+
for bb_codec in self._bb_codecs:
181157
if bb_out is None:
182158
return None
183-
bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, spec)
184-
spec = bb_codec.resolve_metadata(spec)
159+
bb_out = cast("SupportsSyncCodec", bb_codec)._encode_sync(bb_out, self.chunk_spec)
185160

186161
return bb_out # type: ignore[no-any-return]
187162

188163
def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
189-
for codec in self:
164+
for codec in self.codecs:
190165
byte_length = codec.compute_encoded_size(byte_length, array_spec)
191166
array_spec = codec.resolve_metadata(array_spec)
192167
return byte_length
193168

194-
def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
195-
for codec in self:
196-
chunk_spec = codec.resolve_metadata(chunk_spec)
197-
return chunk_spec
198-
199169

200170
@dataclass(frozen=True)
201171
class BatchedCodecPipeline(CodecPipeline):

tests/test_sync_codec_pipeline.py

Lines changed: 23 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any
3+
from typing import Any
44

55
import numpy as np
6-
import pytest
76

87
from zarr.codecs.bytes import BytesCodec
98
from zarr.codecs.gzip import GzipCodec
109
from zarr.codecs.transpose import TransposeCodec
1110
from zarr.codecs.zstd import ZstdCodec
1211
from zarr.core.array_spec import ArrayConfig, ArraySpec
1312
from zarr.core.buffer import NDBuffer, default_buffer_prototype
13+
from zarr.core.codec_pipeline import CodecChain
1414
from zarr.core.dtype import get_data_type_from_native_dtype
1515

16-
if TYPE_CHECKING:
17-
from zarr.abc.codec import Codec
18-
1916

2017
def _make_array_spec(shape: tuple[int, ...], dtype: np.dtype[np.generic]) -> ArraySpec:
2118
zdtype = get_data_type_from_native_dtype(dtype)
@@ -33,124 +30,50 @@ def _make_nd_buffer(arr: np.ndarray[Any, np.dtype[Any]]) -> NDBuffer:
3330

3431

3532
class TestCodecChain:
36-
def test_from_codecs_bytes_only(self) -> None:
37-
from zarr.core.codec_pipeline import CodecChain
38-
39-
chain = CodecChain.from_codecs([BytesCodec()])
40-
assert chain.array_array_codecs == ()
41-
assert isinstance(chain.array_bytes_codec, BytesCodec)
42-
assert chain.bytes_bytes_codecs == ()
43-
assert chain._all_sync is True
44-
45-
def test_from_codecs_with_compression(self) -> None:
46-
from zarr.core.codec_pipeline import CodecChain
47-
48-
chain = CodecChain.from_codecs([BytesCodec(), GzipCodec()])
49-
assert isinstance(chain.array_bytes_codec, BytesCodec)
50-
assert len(chain.bytes_bytes_codecs) == 1
51-
assert isinstance(chain.bytes_bytes_codecs[0], GzipCodec)
52-
assert chain._all_sync is True
53-
54-
def test_from_codecs_with_transpose(self) -> None:
55-
from zarr.core.codec_pipeline import CodecChain
56-
57-
chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec()])
58-
assert len(chain.array_array_codecs) == 1
59-
assert isinstance(chain.array_array_codecs[0], TransposeCodec)
60-
assert isinstance(chain.array_bytes_codec, BytesCodec)
61-
assert chain._all_sync is True
33+
def test_all_sync(self) -> None:
34+
spec = _make_array_spec((100,), np.dtype("float64"))
35+
chain = CodecChain((BytesCodec(),), spec)
36+
assert chain.all_sync is True
6237

63-
def test_from_codecs_full_chain(self) -> None:
64-
from zarr.core.codec_pipeline import CodecChain
38+
def test_all_sync_with_compression(self) -> None:
39+
spec = _make_array_spec((100,), np.dtype("float64"))
40+
chain = CodecChain((BytesCodec(), GzipCodec()), spec)
41+
assert chain.all_sync is True
6542

66-
chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()])
67-
assert len(chain.array_array_codecs) == 1
68-
assert isinstance(chain.array_bytes_codec, BytesCodec)
69-
assert len(chain.bytes_bytes_codecs) == 1
70-
assert chain._all_sync is True
71-
72-
def test_iter(self) -> None:
73-
from zarr.core.codec_pipeline import CodecChain
74-
75-
codecs: list[Codec] = [TransposeCodec(order=(1, 0)), BytesCodec(), GzipCodec()]
76-
chain = CodecChain.from_codecs(codecs)
77-
assert list(chain) == codecs
78-
79-
def test_frozen(self) -> None:
80-
from zarr.core.codec_pipeline import CodecChain
81-
82-
chain = CodecChain.from_codecs([BytesCodec()])
83-
with pytest.raises(AttributeError):
84-
chain.array_bytes_codec = BytesCodec() # type: ignore[misc]
43+
def test_all_sync_full_chain(self) -> None:
44+
spec = _make_array_spec((3, 4), np.dtype("float64"))
45+
chain = CodecChain((TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec()), spec)
46+
assert chain.all_sync is True
8547

8648
def test_encode_decode_roundtrip_bytes_only(self) -> None:
87-
from zarr.core.codec_pipeline import CodecChain
88-
89-
chain = CodecChain.from_codecs([BytesCodec()])
9049
arr = np.arange(100, dtype="float64")
9150
spec = _make_array_spec(arr.shape, arr.dtype)
92-
chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain])
51+
chain = CodecChain((BytesCodec(),), spec)
9352
nd_buf = _make_nd_buffer(arr)
9453

95-
encoded = chain_evolved.encode_chunk(nd_buf, spec)
54+
encoded = chain.encode_chunk(nd_buf)
9655
assert encoded is not None
97-
decoded = chain_evolved.decode_chunk(encoded, spec)
98-
assert decoded is not None
56+
decoded = chain.decode_chunk(encoded)
9957
np.testing.assert_array_equal(arr, decoded.as_numpy_array())
10058

10159
def test_encode_decode_roundtrip_with_compression(self) -> None:
102-
from zarr.core.codec_pipeline import CodecChain
103-
104-
chain = CodecChain.from_codecs([BytesCodec(), GzipCodec(level=1)])
10560
arr = np.arange(100, dtype="float64")
10661
spec = _make_array_spec(arr.shape, arr.dtype)
107-
chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain])
62+
chain = CodecChain((BytesCodec(), GzipCodec(level=1)), spec)
10863
nd_buf = _make_nd_buffer(arr)
10964

110-
encoded = chain_evolved.encode_chunk(nd_buf, spec)
65+
encoded = chain.encode_chunk(nd_buf)
11166
assert encoded is not None
112-
decoded = chain_evolved.decode_chunk(encoded, spec)
113-
assert decoded is not None
67+
decoded = chain.decode_chunk(encoded)
11468
np.testing.assert_array_equal(arr, decoded.as_numpy_array())
11569

11670
def test_encode_decode_roundtrip_with_transpose(self) -> None:
117-
from zarr.core.codec_pipeline import CodecChain
118-
119-
chain = CodecChain.from_codecs(
120-
[TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)]
121-
)
12271
arr = np.arange(12, dtype="float64").reshape(3, 4)
12372
spec = _make_array_spec(arr.shape, arr.dtype)
124-
chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain])
73+
chain = CodecChain((TransposeCodec(order=(1, 0)), BytesCodec(), ZstdCodec(level=1)), spec)
12574
nd_buf = _make_nd_buffer(arr)
12675

127-
encoded = chain_evolved.encode_chunk(nd_buf, spec)
76+
encoded = chain.encode_chunk(nd_buf)
12877
assert encoded is not None
129-
decoded = chain_evolved.decode_chunk(encoded, spec)
130-
assert decoded is not None
78+
decoded = chain.decode_chunk(encoded)
13179
np.testing.assert_array_equal(arr, decoded.as_numpy_array())
132-
133-
def test_resolve_metadata_chain(self) -> None:
134-
from zarr.core.codec_pipeline import CodecChain
135-
136-
chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec(), GzipCodec()])
137-
arr = np.zeros((3, 4), dtype="float64")
138-
spec = _make_array_spec(arr.shape, arr.dtype)
139-
chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain])
140-
141-
aa_chain, ab_pair, bb_chain = chain_evolved.resolve_metadata_chain(spec)
142-
assert len(aa_chain) == 1
143-
assert aa_chain[0][1].shape == (3, 4) # spec before transpose
144-
_ab_codec, ab_spec = ab_pair
145-
assert ab_spec.shape == (4, 3) # spec after transpose
146-
assert len(bb_chain) == 1
147-
148-
def test_resolve_metadata(self) -> None:
149-
from zarr.core.codec_pipeline import CodecChain
150-
151-
chain = CodecChain.from_codecs([TransposeCodec(order=(1, 0)), BytesCodec()])
152-
spec = _make_array_spec((3, 4), np.dtype("float64"))
153-
chain_evolved = CodecChain.from_codecs([c.evolve_from_array_spec(spec) for c in chain])
154-
resolved = chain_evolved.resolve_metadata(spec)
155-
# After transpose (1,0) + bytes, shape should reflect the transpose
156-
assert resolved.shape == (4, 3)

0 commit comments

Comments
 (0)