33import os
44import threading
55from concurrent .futures import ThreadPoolExecutor
6- from dataclasses import dataclass , field , replace
6+ from dataclasses import dataclass , field
77from itertools import islice , pairwise
88from typing import TYPE_CHECKING , Any
99from warnings import warn
@@ -87,24 +87,23 @@ def fill_value_or_default(chunk_spec: ArraySpec) -> Any:
8787
8888@dataclass (slots = True , kw_only = True )
8989class ChunkTransform :
90- """A synchronous codec chain bound to an ArraySpec .
90+ """A synchronous codec chain.
9191
92- Provides `encode` and `decode` for pure-compute codec operations
93- (no IO, no threading, no batching).
92+ Provides `encode_chunk` and `decode_chunk` for pure-compute codec
93+ operations (no IO, no threading, no batching). The `chunk_spec` is
94+ supplied per call so the same transform can be reused across chunks
95+ with different shapes, prototypes, etc.
9496
9597 All codecs must implement `SupportsSyncCodec`. Construction will
9698 raise `TypeError` if any codec does not.
9799 """
98100
99101 codecs : tuple [Codec , ...]
100- array_spec : ArraySpec
101102
102- # (sync codec, input_spec) pairs in pipeline order.
103- _aa_codecs : tuple [tuple [SupportsSyncCodec [NDBuffer , NDBuffer ], ArraySpec ], ...] = field (
103+ _aa_codecs : tuple [SupportsSyncCodec [NDBuffer , NDBuffer ], ...] = field (
104104 init = False , repr = False , compare = False
105105 )
106106 _ab_codec : SupportsSyncCodec [NDBuffer , Buffer ] = field (init = False , repr = False , compare = False )
107- _ab_spec : ArraySpec = field (init = False , repr = False , compare = False )
108107 _bb_codecs : tuple [SupportsSyncCodec [Buffer , Buffer ], ...] = field (
109108 init = False , repr = False , compare = False
110109 )
@@ -118,131 +117,78 @@ def __post_init__(self) -> None:
118117 )
119118
120119 aa , ab , bb = codecs_from_list (list (self .codecs ))
120+ for c in (* aa , ab , * bb ):
121+ assert isinstance (c , SupportsSyncCodec )
122+ self ._aa_codecs = tuple (aa ) # type: ignore[assignment]
123+ self ._ab_codec = ab # type: ignore[assignment]
124+ self ._bb_codecs = tuple (bb ) # type: ignore[assignment]
125+
126+ _cached_key : tuple [tuple [int , ...], int ] | None = field (
127+ init = False , repr = False , compare = False , default = None
128+ )
129+ _cached_aa_specs : tuple [ArraySpec , ...] | None = field (
130+ init = False , repr = False , compare = False , default = None
131+ )
132+ _cached_ab_spec : ArraySpec | None = field (
133+ init = False , repr = False , compare = False , default = None
134+ )
121135
122- aa_codecs : list [tuple [SupportsSyncCodec [NDBuffer , NDBuffer ], ArraySpec ]] = []
123- spec = self .array_spec
124- for aa_codec in aa :
125- assert isinstance (aa_codec , SupportsSyncCodec )
126- aa_codecs .append ((aa_codec , spec ))
127- spec = aa_codec .resolve_metadata (spec )
128-
129- self ._aa_codecs = tuple (aa_codecs )
130- assert isinstance (ab , SupportsSyncCodec )
131- self ._ab_codec = ab
132- self ._ab_spec = spec
133- bb_sync : list [SupportsSyncCodec [Buffer , Buffer ]] = []
134- for bb_codec in bb :
135- assert isinstance (bb_codec , SupportsSyncCodec )
136- bb_sync .append (bb_codec )
137- self ._bb_codecs = tuple (bb_sync )
138-
139- def _spec_for_shape (
140- self , shape : tuple [int , ...], prototype : BufferPrototype | None = None
141- ) -> ArraySpec :
142- """Build an ArraySpec with the given shape (and optional prototype)."""
143- if shape == self ._ab_spec .shape and (
144- prototype is None or prototype is self ._ab_spec .prototype
145- ):
146- return self ._ab_spec
147- if prototype is None :
148- return replace (self ._ab_spec , shape = shape )
149- return replace (self ._ab_spec , shape = shape , prototype = prototype )
136+ def _resolve_specs (self , chunk_spec : ArraySpec ) -> tuple [tuple [ArraySpec , ...], ArraySpec ]:
137+ """Return per-AA-codec input specs and the AB spec for ``chunk_spec``.
150138
151- def decode_chunk (
152- self ,
153- chunk_bytes : Buffer ,
154- chunk_shape : tuple [int , ...] | None = None ,
155- prototype : BufferPrototype | None = None ,
156- ) -> NDBuffer :
139+ The codec chain only changes ``shape`` (via TransposeCodec etc.) —
140+ ``prototype``, ``dtype``, ``fill_value``, and ``config`` are
141+ invariant. We cache the resolved spec chain keyed on
142+ ``(chunk_spec.shape, id(chunk_spec))``, and reuse it directly
143+ when the same ``chunk_spec`` is passed again. For a different
144+ ``chunk_spec`` with the same shape, we recompute (cheap).
145+ """
146+ if not self ._aa_codecs :
147+ return (), chunk_spec
148+ key = (chunk_spec .shape , id (chunk_spec ))
149+ if self ._cached_key == key :
150+ assert self ._cached_aa_specs is not None
151+ assert self ._cached_ab_spec is not None
152+ return self ._cached_aa_specs , self ._cached_ab_spec
153+
154+ aa_specs : list [ArraySpec ] = []
155+ spec = chunk_spec
156+ for aa_codec in self ._aa_codecs :
157+ aa_specs .append (spec )
158+ spec = aa_codec .resolve_metadata (spec ) # type: ignore[attr-defined]
159+ aa_specs_t = tuple (aa_specs )
160+ self ._cached_key = key
161+ self ._cached_aa_specs = aa_specs_t
162+ self ._cached_ab_spec = spec
163+ return aa_specs_t , spec
164+
165+ def decode_chunk (self , chunk_bytes : Buffer , chunk_spec : ArraySpec ) -> NDBuffer :
157166 """Decode a single chunk through the full codec chain, synchronously.
158167
159168 Pure compute -- no IO.
160-
161- Parameters
162- ----------
163- chunk_bytes : Buffer
164- The encoded chunk bytes.
165- chunk_shape : tuple[int, ...] or None
166- The shape of this chunk. If None, uses the shape from the
167- ArraySpec provided at construction. Required for rectilinear
168- grids where chunks have different shapes.
169- prototype : BufferPrototype or None
170- The buffer prototype for the output. If None, uses the
171- prototype from the ArraySpec provided at construction.
172- Required when decoding into a non-default buffer (e.g. GPU).
173169 """
174- if chunk_shape is None and (prototype is None or prototype is self ._ab_spec .prototype ):
175- # Use pre-computed specs
176- ab_spec = self ._ab_spec
177- aa_specs : list [ArraySpec ] = [s for _ , s in self ._aa_codecs ]
178- else :
179- # Resolve chunk_shape through the aa_codecs to get the correct
180- # spec for the ab_codec (e.g., TransposeCodec changes the shape).
181- base_spec = self ._spec_for_shape (
182- chunk_shape if chunk_shape is not None else self ._ab_spec .shape ,
183- prototype = prototype ,
184- )
185- aa_specs = []
186- spec = base_spec
187- for aa_codec , _ in self ._aa_codecs :
188- aa_specs .append (spec )
189- spec = aa_codec .resolve_metadata (spec ) # type: ignore[attr-defined]
190- ab_spec = spec
170+ aa_specs , ab_spec = self ._resolve_specs (chunk_spec )
191171
192172 data : Buffer = chunk_bytes
193173 for bb_codec in reversed (self ._bb_codecs ):
194174 data = bb_codec ._decode_sync (data , ab_spec )
195175
196176 chunk_array : NDBuffer = self ._ab_codec ._decode_sync (data , ab_spec )
197177
198- for (aa_codec , _ ), aa_spec in zip (
199- reversed (self ._aa_codecs ), reversed (aa_specs ), strict = True
200- ):
178+ for aa_codec , aa_spec in zip (reversed (self ._aa_codecs ), reversed (aa_specs ), strict = True ):
201179 chunk_array = aa_codec ._decode_sync (chunk_array , aa_spec )
202180
203181 return chunk_array
204182
205- def encode_chunk (
206- self ,
207- chunk_array : NDBuffer ,
208- chunk_shape : tuple [int , ...] | None = None ,
209- prototype : BufferPrototype | None = None ,
210- ) -> Buffer | None :
183+ def encode_chunk (self , chunk_array : NDBuffer , chunk_spec : ArraySpec ) -> Buffer | None :
211184 """Encode a single chunk through the full codec chain, synchronously.
212185
213186 Pure compute -- no IO.
214-
215- Parameters
216- ----------
217- chunk_array : NDBuffer
218- The chunk data to encode.
219- chunk_shape : tuple[int, ...] or None
220- The shape of this chunk. If None, uses the shape from the
221- ArraySpec provided at construction.
222- prototype : BufferPrototype or None
223- The buffer prototype to use for intermediate buffers. If
224- None, uses the prototype from the ArraySpec provided at
225- construction. Required when encoding non-default buffers
226- (e.g. GPU) so the codec chain produces matching buffer
227- types.
228187 """
229- if chunk_shape is None and (prototype is None or prototype is self ._ab_spec .prototype ):
230- ab_spec = self ._ab_spec
231- aa_specs : list [ArraySpec ] = [s for _ , s in self ._aa_codecs ]
232- else :
233- base_spec = self ._spec_for_shape (
234- chunk_shape if chunk_shape is not None else self ._ab_spec .shape ,
235- prototype = prototype ,
236- )
237- aa_specs = []
238- spec = base_spec
239- for aa_codec , _ in self ._aa_codecs :
240- aa_specs .append (spec )
241- spec = aa_codec .resolve_metadata (spec ) # type: ignore[attr-defined]
242- ab_spec = spec
188+ aa_specs , ab_spec = self ._resolve_specs (chunk_spec )
243189
244190 aa_data : NDBuffer = chunk_array
245- for ( aa_codec , _ ) , aa_spec in zip (self ._aa_codecs , aa_specs , strict = True ):
191+ for aa_codec , aa_spec in zip (self ._aa_codecs , aa_specs , strict = True ):
246192 aa_result = aa_codec ._encode_sync (aa_data , aa_spec )
247193 if aa_result is None :
248194 return None
@@ -824,9 +770,7 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
824770 aa , ab , bb = codecs_from_list (evolved_codecs )
825771
826772 try :
827- sync_transform : ChunkTransform | None = ChunkTransform (
828- codecs = evolved_codecs , array_spec = array_spec
829- )
773+ sync_transform : ChunkTransform | None = ChunkTransform (codecs = evolved_codecs )
830774 except TypeError :
831775 sync_transform = None
832776
@@ -984,15 +928,7 @@ def read_sync(
984928 def _decode_one (raw : Buffer | None , chunk_spec : ArraySpec ) -> NDBuffer | None :
985929 if raw is None :
986930 return None
987- chunk_shape = (
988- chunk_spec .shape if chunk_spec .shape != transform .array_spec .shape else None
989- )
990- prototype = (
991- chunk_spec .prototype
992- if chunk_spec .prototype is not transform .array_spec .prototype
993- else None
994- )
995- return transform .decode_chunk (raw , chunk_shape = chunk_shape , prototype = prototype )
931+ return transform .decode_chunk (raw , chunk_spec )
996932
997933 specs = [cs for _ , cs , * _ in batch ]
998934 if n_workers > 0 and len (batch ) > 1 :
@@ -1071,21 +1007,10 @@ def _process_one(
10711007 ) -> Buffer | None :
10721008 _ , chunk_spec , chunk_selection , out_selection , is_complete = batch [idx ]
10731009 existing_bytes = existing_buffers [idx ]
1074- chunk_shape = (
1075- chunk_spec .shape if chunk_spec .shape != transform .array_spec .shape else None
1076- )
1077-
1078- prototype = (
1079- chunk_spec .prototype
1080- if chunk_spec .prototype is not transform .array_spec .prototype
1081- else None
1082- )
10831010
10841011 existing_chunk_array : NDBuffer | None = None
10851012 if existing_bytes is not None :
1086- existing_chunk_array = transform .decode_chunk (
1087- existing_bytes , chunk_shape = chunk_shape , prototype = prototype
1088- )
1013+ existing_chunk_array = transform .decode_chunk (existing_bytes , chunk_spec )
10891014
10901015 chunk_array = self ._merge_chunk_array (
10911016 existing_chunk_array ,
@@ -1103,7 +1028,7 @@ def _process_one(
11031028 ):
11041029 return None
11051030
1106- return transform .encode_chunk (chunk_array , chunk_shape = chunk_shape , prototype = prototype )
1031+ return transform .encode_chunk (chunk_array , chunk_spec )
11071032
11081033 indices = list (range (len (batch )))
11091034 if n_workers > 0 and len (batch ) > 1 :
0 commit comments