Skip to content

Commit 82b9e24

Browse files
committed
refactor: take chunk_spec per call in ChunkTransform
Previously, ChunkTransform cached resolved AA/AB specs at construction using whatever prototype was passed in (the default CPU one). When runtime chunk_specs arrived with a different prototype (e.g. GPU), those cached specs forced the codec chain to use the wrong prototype, blowing up in BB codecs that wrap the buffer in CPU-only paths. The previous fix added an opt-in `prototype` parameter to encode_chunk and decode_chunk, but that papered over the real problem: the cached specs were keyed on a partial input. The right model is the one BatchedCodecPipeline already uses — derive the spec chain from the runtime chunk_spec on every call. ChunkTransform now takes a chunk_spec per call. The codec chain only mutates `shape` (via TransposeCodec etc.); prototype, dtype, fill_value are invariant — so the cached spec chain is keyed on `(chunk_spec.shape, id(chunk_spec))`. In steady state (same spec reused per inner chunk), the cache hits and the call is identical to the previous fast path. When chunk_spec changes (rectilinear grids, edge chunks), the spec chain is recomputed. For codec chains with no AA codecs (the common case), the resolution short-circuits with no allocation. Benchmarks vs the prior fix: 0 regressions worse than 4% (within noise), up to 21% faster on simple cases. GPU prototype now flows through correctly without an opt-in parameter.
1 parent 74351e7 commit 82b9e24

4 files changed

Lines changed: 92 additions & 159 deletions

File tree

src/zarr/codecs/sharding.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -429,34 +429,40 @@ def validate(
429429
)
430430

431431
def _get_inner_chunk_transform(self, shard_spec: ArraySpec) -> Any:
432-
"""Build a ChunkTransform for inner codecs, bound to the inner chunk spec."""
432+
"""Build a ChunkTransform for the inner codec chain.
433+
434+
The cache key is the shard_spec because evolved codecs may
435+
depend on it. The runtime chunk_spec is supplied per call.
436+
"""
433437
from zarr.core.codec_pipeline import ChunkTransform
434438

435439
chunk_spec = self._get_chunk_spec(shard_spec)
436440
evolved = tuple(c.evolve_from_array_spec(array_spec=chunk_spec) for c in self.codecs)
437-
return ChunkTransform(codecs=evolved, array_spec=chunk_spec)
441+
return ChunkTransform(codecs=evolved)
438442

439443
def _get_index_chunk_transform(self, chunks_per_shard: tuple[int, ...]) -> Any:
440-
"""Build a ChunkTransform for index codecs."""
444+
"""Build a ChunkTransform for the index codec chain."""
441445
from zarr.core.codec_pipeline import ChunkTransform
442446

443447
index_spec = self._get_index_chunk_spec(chunks_per_shard)
444448
evolved = tuple(c.evolve_from_array_spec(array_spec=index_spec) for c in self.index_codecs)
445-
return ChunkTransform(codecs=evolved, array_spec=index_spec)
449+
return ChunkTransform(codecs=evolved)
446450

447451
def _decode_shard_index_sync(
448452
self, index_bytes: Buffer, chunks_per_shard: tuple[int, ...]
449453
) -> _ShardIndex:
450454
"""Decode shard index synchronously using ChunkTransform."""
451455
index_transform = self._get_index_chunk_transform(chunks_per_shard)
452-
index_array = index_transform.decode_chunk(index_bytes)
456+
index_spec = self._get_index_chunk_spec(chunks_per_shard)
457+
index_array = index_transform.decode_chunk(index_bytes, index_spec)
453458
return _ShardIndex(index_array.as_numpy_array())
454459

455460
def _encode_shard_index_sync(self, index: _ShardIndex) -> Buffer:
456461
"""Encode shard index synchronously using ChunkTransform."""
457462
index_transform = self._get_index_chunk_transform(index.chunks_per_shard)
463+
index_spec = self._get_index_chunk_spec(index.chunks_per_shard)
458464
index_nd = get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths)
459-
result: Buffer | None = index_transform.encode_chunk(index_nd)
465+
result: Buffer | None = index_transform.encode_chunk(index_nd, index_spec)
460466
assert result is not None
461467
return result
462468

@@ -511,7 +517,7 @@ def _decode_sync(
511517
except KeyError:
512518
out[out_selection] = shard_spec.fill_value
513519
continue
514-
chunk_array = inner_transform.decode_chunk(chunk_bytes)
520+
chunk_array = inner_transform.decode_chunk(chunk_bytes, chunk_spec)
515521
out[out_selection] = chunk_array[chunk_selection]
516522

517523
return out
@@ -524,6 +530,7 @@ def _encode_sync(
524530
"""Encode a full shard synchronously."""
525531
shard_shape = shard_spec.shape
526532
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
533+
chunk_spec = self._get_chunk_spec(shard_spec)
527534
inner_transform = self._get_inner_chunk_transform(shard_spec)
528535

529536
indexer = BasicIndexer(
@@ -546,7 +553,7 @@ def _encode_sync(
546553
if skip_empty and chunk_array.all_equal(fill_value):
547554
shard_builder[chunk_coords] = None
548555
else:
549-
encoded = inner_transform.encode_chunk(chunk_array)
556+
encoded = inner_transform.encode_chunk(chunk_array, chunk_spec)
550557
shard_builder[chunk_coords] = encoded
551558

552559
return self._encode_shard_dict_sync(
@@ -636,10 +643,12 @@ def _byte_offset(coords: tuple[int, ...]) -> int:
636643
existing_chunk_bytes = existing[
637644
byte_offset : byte_offset + chunk_byte_length
638645
]
639-
chunk_array = inner_transform.decode_chunk(existing_chunk_bytes).copy()
646+
chunk_array = inner_transform.decode_chunk(
647+
existing_chunk_bytes, chunk_spec
648+
).copy()
640649
chunk_array[chunk_sel] = chunk_value
641650

642-
encoded = inner_transform.encode_chunk(chunk_array)
651+
encoded = inner_transform.encode_chunk(chunk_array, chunk_spec)
643652
if encoded is not None:
644653
store.set_range_sync(key, encoded, byte_offset)
645654
index.set_chunk_slice(
@@ -685,7 +694,7 @@ def _byte_offset(coords: tuple[int, ...]) -> int:
685694
else:
686695
existing_raw = shard_dict.get(chunk_coords)
687696
if existing_raw is not None:
688-
chunk_array = inner_transform.decode_chunk(existing_raw).copy()
697+
chunk_array = inner_transform.decode_chunk(existing_raw, chunk_spec).copy()
689698
else:
690699
chunk_array = chunk_spec.prototype.nd_buffer.create(
691700
shape=self.chunk_shape,
@@ -698,7 +707,7 @@ def _byte_offset(coords: tuple[int, ...]) -> int:
698707
if skip_empty and chunk_array.all_equal(fill_value):
699708
shard_dict[chunk_coords] = None
700709
else:
701-
shard_dict[chunk_coords] = inner_transform.encode_chunk(chunk_array)
710+
shard_dict[chunk_coords] = inner_transform.encode_chunk(chunk_array, chunk_spec)
702711

703712
blob = self._encode_shard_dict_sync(
704713
shard_dict,

src/zarr/core/codec_pipeline.py

Lines changed: 60 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import threading
55
from concurrent.futures import ThreadPoolExecutor
6-
from dataclasses import dataclass, field, replace
6+
from dataclasses import dataclass, field
77
from itertools import islice, pairwise
88
from typing import TYPE_CHECKING, Any
99
from warnings import warn
@@ -87,24 +87,23 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any:
8787

8888
@dataclass(slots=True, kw_only=True)
8989
class ChunkTransform:
90-
"""A synchronous codec chain bound to an ArraySpec.
90+
"""A synchronous codec chain.
9191
92-
Provides `encode` and `decode` for pure-compute codec operations
93-
(no IO, no threading, no batching).
92+
Provides `encode_chunk` and `decode_chunk` for pure-compute codec
93+
operations (no IO, no threading, no batching). The `chunk_spec` is
94+
supplied per call so the same transform can be reused across chunks
95+
with different shapes, prototypes, etc.
9496
9597
All codecs must implement `SupportsSyncCodec`. Construction will
9698
raise `TypeError` if any codec does not.
9799
"""
98100

99101
codecs: tuple[Codec, ...]
100-
array_spec: ArraySpec
101102

102-
# (sync codec, input_spec) pairs in pipeline order.
103-
_aa_codecs: tuple[tuple[SupportsSyncCodec[NDBuffer, NDBuffer], ArraySpec], ...] = field(
103+
_aa_codecs: tuple[SupportsSyncCodec[NDBuffer, NDBuffer], ...] = field(
104104
init=False, repr=False, compare=False
105105
)
106106
_ab_codec: SupportsSyncCodec[NDBuffer, Buffer] = field(init=False, repr=False, compare=False)
107-
_ab_spec: ArraySpec = field(init=False, repr=False, compare=False)
108107
_bb_codecs: tuple[SupportsSyncCodec[Buffer, Buffer], ...] = field(
109108
init=False, repr=False, compare=False
110109
)
@@ -118,131 +117,78 @@ def __post_init__(self) -> None:
118117
)
119118

120119
aa, ab, bb = codecs_from_list(list(self.codecs))
120+
for c in (*aa, ab, *bb):
121+
assert isinstance(c, SupportsSyncCodec)
122+
self._aa_codecs = tuple(aa) # type: ignore[assignment]
123+
self._ab_codec = ab # type: ignore[assignment]
124+
self._bb_codecs = tuple(bb) # type: ignore[assignment]
125+
126+
_cached_key: tuple[tuple[int, ...], int] | None = field(
127+
init=False, repr=False, compare=False, default=None
128+
)
129+
_cached_aa_specs: tuple[ArraySpec, ...] | None = field(
130+
init=False, repr=False, compare=False, default=None
131+
)
132+
_cached_ab_spec: ArraySpec | None = field(
133+
init=False, repr=False, compare=False, default=None
134+
)
121135

122-
aa_codecs: list[tuple[SupportsSyncCodec[NDBuffer, NDBuffer], ArraySpec]] = []
123-
spec = self.array_spec
124-
for aa_codec in aa:
125-
assert isinstance(aa_codec, SupportsSyncCodec)
126-
aa_codecs.append((aa_codec, spec))
127-
spec = aa_codec.resolve_metadata(spec)
128-
129-
self._aa_codecs = tuple(aa_codecs)
130-
assert isinstance(ab, SupportsSyncCodec)
131-
self._ab_codec = ab
132-
self._ab_spec = spec
133-
bb_sync: list[SupportsSyncCodec[Buffer, Buffer]] = []
134-
for bb_codec in bb:
135-
assert isinstance(bb_codec, SupportsSyncCodec)
136-
bb_sync.append(bb_codec)
137-
self._bb_codecs = tuple(bb_sync)
138-
139-
def _spec_for_shape(
140-
self, shape: tuple[int, ...], prototype: BufferPrototype | None = None
141-
) -> ArraySpec:
142-
"""Build an ArraySpec with the given shape (and optional prototype)."""
143-
if shape == self._ab_spec.shape and (
144-
prototype is None or prototype is self._ab_spec.prototype
145-
):
146-
return self._ab_spec
147-
if prototype is None:
148-
return replace(self._ab_spec, shape=shape)
149-
return replace(self._ab_spec, shape=shape, prototype=prototype)
136+
def _resolve_specs(self, chunk_spec: ArraySpec) -> tuple[tuple[ArraySpec, ...], ArraySpec]:
137+
"""Return per-AA-codec input specs and the AB spec for ``chunk_spec``.
150138
151-
def decode_chunk(
152-
self,
153-
chunk_bytes: Buffer,
154-
chunk_shape: tuple[int, ...] | None = None,
155-
prototype: BufferPrototype | None = None,
156-
) -> NDBuffer:
139+
The codec chain only changes ``shape`` (via TransposeCodec etc.) —
140+
``prototype``, ``dtype``, ``fill_value``, and ``config`` are
141+
invariant. We cache the resolved spec chain keyed on
142+
``(chunk_spec.shape, id(chunk_spec))``, and reuse it directly
143+
when the same ``chunk_spec`` is passed again. For a different
144+
``chunk_spec`` with the same shape, we recompute (cheap).
145+
"""
146+
if not self._aa_codecs:
147+
return (), chunk_spec
148+
key = (chunk_spec.shape, id(chunk_spec))
149+
if self._cached_key == key:
150+
assert self._cached_aa_specs is not None
151+
assert self._cached_ab_spec is not None
152+
return self._cached_aa_specs, self._cached_ab_spec
153+
154+
aa_specs: list[ArraySpec] = []
155+
spec = chunk_spec
156+
for aa_codec in self._aa_codecs:
157+
aa_specs.append(spec)
158+
spec = aa_codec.resolve_metadata(spec) # type: ignore[attr-defined]
159+
aa_specs_t = tuple(aa_specs)
160+
self._cached_key = key
161+
self._cached_aa_specs = aa_specs_t
162+
self._cached_ab_spec = spec
163+
return aa_specs_t, spec
164+
165+
def decode_chunk(self, chunk_bytes: Buffer, chunk_spec: ArraySpec) -> NDBuffer:
157166
"""Decode a single chunk through the full codec chain, synchronously.
158167
159168
Pure compute -- no IO.
160-
161-
Parameters
162-
----------
163-
chunk_bytes : Buffer
164-
The encoded chunk bytes.
165-
chunk_shape : tuple[int, ...] or None
166-
The shape of this chunk. If None, uses the shape from the
167-
ArraySpec provided at construction. Required for rectilinear
168-
grids where chunks have different shapes.
169-
prototype : BufferPrototype or None
170-
The buffer prototype for the output. If None, uses the
171-
prototype from the ArraySpec provided at construction.
172-
Required when decoding into a non-default buffer (e.g. GPU).
173169
"""
174-
if chunk_shape is None and (prototype is None or prototype is self._ab_spec.prototype):
175-
# Use pre-computed specs
176-
ab_spec = self._ab_spec
177-
aa_specs: list[ArraySpec] = [s for _, s in self._aa_codecs]
178-
else:
179-
# Resolve chunk_shape through the aa_codecs to get the correct
180-
# spec for the ab_codec (e.g., TransposeCodec changes the shape).
181-
base_spec = self._spec_for_shape(
182-
chunk_shape if chunk_shape is not None else self._ab_spec.shape,
183-
prototype=prototype,
184-
)
185-
aa_specs = []
186-
spec = base_spec
187-
for aa_codec, _ in self._aa_codecs:
188-
aa_specs.append(spec)
189-
spec = aa_codec.resolve_metadata(spec) # type: ignore[attr-defined]
190-
ab_spec = spec
170+
aa_specs, ab_spec = self._resolve_specs(chunk_spec)
191171

192172
data: Buffer = chunk_bytes
193173
for bb_codec in reversed(self._bb_codecs):
194174
data = bb_codec._decode_sync(data, ab_spec)
195175

196176
chunk_array: NDBuffer = self._ab_codec._decode_sync(data, ab_spec)
197177

198-
for (aa_codec, _), aa_spec in zip(
199-
reversed(self._aa_codecs), reversed(aa_specs), strict=True
200-
):
178+
for aa_codec, aa_spec in zip(reversed(self._aa_codecs), reversed(aa_specs), strict=True):
201179
chunk_array = aa_codec._decode_sync(chunk_array, aa_spec)
202180

203181
return chunk_array
204182

205-
def encode_chunk(
206-
self,
207-
chunk_array: NDBuffer,
208-
chunk_shape: tuple[int, ...] | None = None,
209-
prototype: BufferPrototype | None = None,
210-
) -> Buffer | None:
183+
def encode_chunk(self, chunk_array: NDBuffer, chunk_spec: ArraySpec) -> Buffer | None:
211184
"""Encode a single chunk through the full codec chain, synchronously.
212185
213186
Pure compute -- no IO.
214-
215-
Parameters
216-
----------
217-
chunk_array : NDBuffer
218-
The chunk data to encode.
219-
chunk_shape : tuple[int, ...] or None
220-
The shape of this chunk. If None, uses the shape from the
221-
ArraySpec provided at construction.
222-
prototype : BufferPrototype or None
223-
The buffer prototype to use for intermediate buffers. If
224-
None, uses the prototype from the ArraySpec provided at
225-
construction. Required when encoding non-default buffers
226-
(e.g. GPU) so the codec chain produces matching buffer
227-
types.
228187
"""
229-
if chunk_shape is None and (prototype is None or prototype is self._ab_spec.prototype):
230-
ab_spec = self._ab_spec
231-
aa_specs: list[ArraySpec] = [s for _, s in self._aa_codecs]
232-
else:
233-
base_spec = self._spec_for_shape(
234-
chunk_shape if chunk_shape is not None else self._ab_spec.shape,
235-
prototype=prototype,
236-
)
237-
aa_specs = []
238-
spec = base_spec
239-
for aa_codec, _ in self._aa_codecs:
240-
aa_specs.append(spec)
241-
spec = aa_codec.resolve_metadata(spec) # type: ignore[attr-defined]
242-
ab_spec = spec
188+
aa_specs, ab_spec = self._resolve_specs(chunk_spec)
243189

244190
aa_data: NDBuffer = chunk_array
245-
for (aa_codec, _), aa_spec in zip(self._aa_codecs, aa_specs, strict=True):
191+
for aa_codec, aa_spec in zip(self._aa_codecs, aa_specs, strict=True):
246192
aa_result = aa_codec._encode_sync(aa_data, aa_spec)
247193
if aa_result is None:
248194
return None
@@ -824,9 +770,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
824770
aa, ab, bb = codecs_from_list(evolved_codecs)
825771

826772
try:
827-
sync_transform: ChunkTransform | None = ChunkTransform(
828-
codecs=evolved_codecs, array_spec=array_spec
829-
)
773+
sync_transform: ChunkTransform | None = ChunkTransform(codecs=evolved_codecs)
830774
except TypeError:
831775
sync_transform = None
832776

@@ -984,15 +928,7 @@ def read_sync(
984928
def _decode_one(raw: Buffer | None, chunk_spec: ArraySpec) -> NDBuffer | None:
985929
if raw is None:
986930
return None
987-
chunk_shape = (
988-
chunk_spec.shape if chunk_spec.shape != transform.array_spec.shape else None
989-
)
990-
prototype = (
991-
chunk_spec.prototype
992-
if chunk_spec.prototype is not transform.array_spec.prototype
993-
else None
994-
)
995-
return transform.decode_chunk(raw, chunk_shape=chunk_shape, prototype=prototype)
931+
return transform.decode_chunk(raw, chunk_spec)
996932

997933
specs = [cs for _, cs, *_ in batch]
998934
if n_workers > 0 and len(batch) > 1:
@@ -1071,21 +1007,10 @@ def _process_one(
10711007
) -> Buffer | None:
10721008
_, chunk_spec, chunk_selection, out_selection, is_complete = batch[idx]
10731009
existing_bytes = existing_buffers[idx]
1074-
chunk_shape = (
1075-
chunk_spec.shape if chunk_spec.shape != transform.array_spec.shape else None
1076-
)
1077-
1078-
prototype = (
1079-
chunk_spec.prototype
1080-
if chunk_spec.prototype is not transform.array_spec.prototype
1081-
else None
1082-
)
10831010

10841011
existing_chunk_array: NDBuffer | None = None
10851012
if existing_bytes is not None:
1086-
existing_chunk_array = transform.decode_chunk(
1087-
existing_bytes, chunk_shape=chunk_shape, prototype=prototype
1088-
)
1013+
existing_chunk_array = transform.decode_chunk(existing_bytes, chunk_spec)
10891014

10901015
chunk_array = self._merge_chunk_array(
10911016
existing_chunk_array,
@@ -1103,7 +1028,7 @@ def _process_one(
11031028
):
11041029
return None
11051030

1106-
return transform.encode_chunk(chunk_array, chunk_shape=chunk_shape, prototype=prototype)
1031+
return transform.encode_chunk(chunk_array, chunk_spec)
11071032

11081033
indices = list(range(len(batch)))
11091034
if n_workers > 0 and len(batch) > 1:

0 commit comments

Comments
 (0)