Skip to content

Commit 3297e0d

Browse files
committed
execute performance improvement plan
1 parent 9071954 commit 3297e0d

16 files changed

Lines changed: 1814 additions & 847 deletions

File tree

src/zarr/abc/codec.py

Lines changed: 183 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -2,7 +2,8 @@
22

33
from abc import abstractmethod
44
from collections.abc import Mapping
5-
from typing import TYPE_CHECKING, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable
5+
from dataclasses import dataclass
6+
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeGuard, TypeVar, runtime_checkable
67

78
from typing_extensions import ReadOnly, TypedDict
89

@@ -19,7 +20,7 @@
1920
from zarr.core.array_spec import ArraySpec
2021
from zarr.core.chunk_grids import ChunkGrid
2122
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
22-
from zarr.core.indexing import SelectorTuple
23+
from zarr.core.indexing import ChunkProjection, SelectorTuple
2324
from zarr.core.metadata import ArrayMetadata
2425

2526
__all__ = [
@@ -32,6 +33,7 @@
3233
"CodecInput",
3334
"CodecOutput",
3435
"CodecPipeline",
36+
"PreparedWrite",
3537
"SupportsSyncCodec",
3638
]
3739

@@ -200,9 +202,188 @@ class ArrayArrayCodec(BaseCodec[NDBuffer, NDBuffer]):
200202
"""Base class for array-to-array codecs."""
201203

202204

205+
def _is_complete_selection(selection: Any, shape: tuple[int, ...]) -> bool:
206+
"""Check whether a chunk selection covers the entire chunk shape."""
207+
if not isinstance(selection, tuple):
208+
selection = (selection,)
209+
for sel, dim_len in zip(selection, shape, strict=False):
210+
if isinstance(sel, int):
211+
if dim_len != 1:
212+
return False
213+
elif isinstance(sel, slice):
214+
start, stop, step = sel.indices(dim_len)
215+
if not (start == 0 and stop == dim_len and step == 1):
216+
return False
217+
else:
218+
return False
219+
return True
220+
221+
222+
@dataclass
class PreparedWrite:
    """Result of prepare_write: existing encoded chunk bytes + selection info.

    Produced by ``ArrayBytesCodec.prepare_write`` / ``prepare_write_sync`` and
    consumed by the codec pipeline, which decodes/merges the per-chunk buffers
    and then hands the instance to ``finalize_write`` for serialization.
    Comments below each field document that field.
    """

    # Per-inner-chunk encoded buffers keyed by chunk coordinates; ``None``
    # marks a chunk that is absent from the store.
    chunk_dict: dict[tuple[int, ...], Buffer | None]
    inner_codec_chain: Any  # CodecChain
    inner_chunk_spec: ArraySpec
    indexer: list[ChunkProjection]
    value_selection: SelectorTuple | None = None
    # If not None, slice value with this before using inner out_selections.
    # For sharding: the outer out_selection from batch_info.
    # For non-sharded: None (inner out_selection IS the outer out_selection).
    write_full_shard: bool = True
    # True when the entire shard blob will be written from scratch (either
    # because the shard doesn't exist yet or because the selection is complete).
    # Used by ShardingCodec.finalize_write to decide between set vs set_range.
    is_complete_shard: bool = False
    # True when the outer selection covers the entire shard. When True,
    # the indexer is empty and finalize_write receives the shard value
    # via shard_data. The codec then encodes the full shard in one shot
    # rather than iterating over individual inner chunks.
    shard_data: NDBuffer | None = None
    # The full shard value for complete-selection writes. Set by the pipeline
    # when is_complete_shard is True, before calling finalize_write.
246+
247+
203248
class ArrayBytesCodec(BaseCodec[NDBuffer, Buffer]):
    """Base class for array-to-bytes codecs.

    Besides the encode/decode interface inherited from ``BaseCodec``, this
    class provides the default single-chunk implementation of the
    serialize/deserialize + prepare/finalize read-write protocol used by the
    codec pipeline. ``ShardingCodec`` overrides these hooks to map one stored
    blob onto many inner chunks.
    """

    @property
    def inner_codec_chain(self) -> Any:
        """The codec chain for decoding inner chunks after deserialization.

        Returns None by default — the pipeline should use its own codec_chain.
        ShardingCodec overrides to return its inner codec chain.
        """
        return None

    def deserialize(
        self, raw: Buffer | None, chunk_spec: ArraySpec
    ) -> dict[tuple[int, ...], Buffer | None]:
        """Pure compute: unpack stored bytes into per-inner-chunk buffers.

        Default implementation: single chunk at (0,).
        ShardingCodec overrides to decode shard index and slice blob into per-chunk buffers.
        """
        return {(0,): raw}

    def serialize(
        self, chunk_dict: dict[tuple[int, ...], Buffer | None], chunk_spec: ArraySpec
    ) -> Buffer | None:
        """Pure compute: pack per-inner-chunk buffers into a storage blob.

        Default implementation: return the single chunk's bytes (or None if absent).
        ShardingCodec overrides to concatenate chunks + build index.
        Returns None if all chunks are empty (caller should delete the key).
        """
        return chunk_dict.get((0,))

    @staticmethod
    def _select_region(
        chunk_array: NDBuffer | None, chunk_selection: SelectorTuple
    ) -> NDBuffer | None:
        # Shared tail of prepare_read / prepare_read_sync: slice the decoded
        # chunk down to the requested region, or propagate a missing chunk.
        if chunk_array is None:
            return None
        return chunk_array[chunk_selection]

    def _make_prepared_write(
        self,
        existing: Buffer | None,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        out_selection: SelectorTuple,
        codec_chain: Any,
        is_complete: bool,
    ) -> PreparedWrite:
        # Shared tail of prepare_write / prepare_write_sync: deserialize the
        # existing blob (None when the chunk is fully overwritten or absent)
        # and pair it with a one-entry indexer for the pipeline to merge into.
        chunk_dict = self.deserialize(existing, chunk_spec)
        inner_chain = self.inner_codec_chain or codec_chain
        return PreparedWrite(
            chunk_dict=chunk_dict,
            inner_codec_chain=inner_chain,
            inner_chunk_spec=chunk_spec,
            indexer=[((0,), chunk_selection, out_selection, is_complete)],  # type: ignore[list-item]
        )

    def prepare_read_sync(
        self,
        byte_getter: Any,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        codec_chain: Any,
        aa_chain: Any,
        ab_pair: Any,
        bb_chain: Any,
    ) -> NDBuffer | None:
        """IO + full decode for the selected region. Returns decoded sub-array."""
        raw = byte_getter.get_sync(prototype=chunk_spec.prototype)
        chunk_array: NDBuffer | None = codec_chain.decode_chunk(
            raw, chunk_spec, aa_chain, ab_pair, bb_chain
        )
        return self._select_region(chunk_array, chunk_selection)

    def prepare_write_sync(
        self,
        byte_setter: Any,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        out_selection: SelectorTuple,
        codec_chain: Any,
    ) -> PreparedWrite:
        """IO + deserialize. Returns PreparedWrite for the pipeline to decode/merge/encode."""
        is_complete = _is_complete_selection(chunk_selection, chunk_spec.shape)
        existing: Buffer | None = None
        if not is_complete:
            # Only fetch the stored chunk when we must merge into it; a
            # complete overwrite never needs the existing bytes.
            existing = byte_setter.get_sync(prototype=chunk_spec.prototype)
        return self._make_prepared_write(
            existing, chunk_spec, chunk_selection, out_selection, codec_chain, is_complete
        )

    async def prepare_read(
        self,
        byte_getter: Any,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        codec_chain: Any,
        aa_chain: Any,
        ab_pair: Any,
        bb_chain: Any,
    ) -> NDBuffer | None:
        """Async IO + full decode for the selected region. Returns decoded sub-array."""
        raw = await byte_getter.get(prototype=chunk_spec.prototype)
        chunk_array: NDBuffer | None = codec_chain.decode_chunk(
            raw, chunk_spec, aa_chain, ab_pair, bb_chain
        )
        return self._select_region(chunk_array, chunk_selection)

    async def prepare_write(
        self,
        byte_setter: Any,
        chunk_spec: ArraySpec,
        chunk_selection: SelectorTuple,
        out_selection: SelectorTuple,
        codec_chain: Any,
    ) -> PreparedWrite:
        """Async IO + deserialize. Returns PreparedWrite for the pipeline to decode/merge/encode."""
        is_complete = _is_complete_selection(chunk_selection, chunk_spec.shape)
        existing: Buffer | None = None
        if not is_complete:
            # Only fetch the stored chunk when we must merge into it; a
            # complete overwrite never needs the existing bytes.
            existing = await byte_setter.get(prototype=chunk_spec.prototype)
        return self._make_prepared_write(
            existing, chunk_spec, chunk_selection, out_selection, codec_chain, is_complete
        )

    def finalize_write_sync(
        self, prepared: PreparedWrite, chunk_spec: ArraySpec, byte_setter: Any
    ) -> None:
        """Serialize prepared chunk_dict and write to store.

        Default: serialize to a single blob and call set (or delete if all empty).
        ShardingCodec overrides this for byte-range writes when inner codecs are fixed-size.
        """
        blob = self.serialize(prepared.chunk_dict, chunk_spec)
        if blob is None:
            byte_setter.delete_sync()
        else:
            byte_setter.set_sync(blob)

    async def finalize_write(
        self, prepared: PreparedWrite, chunk_spec: ArraySpec, byte_setter: Any
    ) -> None:
        """Async version of finalize_write_sync."""
        blob = self.serialize(prepared.chunk_dict, chunk_spec)
        if blob is None:
            await byte_setter.delete()
        else:
            await byte_setter.set(blob)
386+
206387

207388
class BytesBytesCodec(BaseCodec[Buffer, Buffer]):
208389
"""Base class for bytes-to-bytes codecs."""

src/zarr/abc/store.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
"ByteGetter",
2121
"ByteSetter",
2222
"Store",
23-
"SyncByteGetter",
24-
"SyncByteSetter",
2523
"set_or_delete",
2624
]
2725

@@ -473,6 +471,21 @@ async def set(self, key: str, value: Buffer) -> None:
473471
"""
474472
...
475473

474+
async def set_range(self, key: str, value: Buffer, start: int) -> None:
    """Write ``value`` into an existing key beginning at byte offset ``start``.

    The key must already exist and ``start + len(value)`` must not exceed
    the current size of the stored value.

    Parameters
    ----------
    key : str
        Key whose stored value is partially overwritten.
    value : Buffer
        Bytes written at offset ``start``.
    start : int
        Byte offset at which to begin writing.

    Raises
    ------
    NotImplementedError
        Always, unless a subclass overrides this method — in-place range
        writes are an optional store capability.
    """
    raise NotImplementedError(f"{type(self).__name__} does not support set_range")
488+
476489
async def set_if_not_exists(self, key: str, value: Buffer) -> None:
477490
"""
478491
Store a key to ``value`` if the key is not already present.
@@ -702,29 +715,13 @@ async def get(
702715

703716
async def set(self, value: Buffer) -> None: ...
704717

718+
async def set_range(self, value: Buffer, start: int) -> None:
    """Write ``value`` into this key beginning at byte offset ``start``."""
    ...
719+
705720
async def delete(self) -> None: ...
706721

707722
async def set_if_not_exists(self, default: Buffer) -> None: ...
708723

709724

710-
@runtime_checkable
711-
class SyncByteGetter(Protocol):
712-
"""Protocol for stores that support synchronous byte reads."""
713-
714-
def get_sync(
715-
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
716-
) -> Buffer | None: ...
717-
718-
719-
@runtime_checkable
720-
class SyncByteSetter(SyncByteGetter, Protocol):
721-
"""Protocol for stores that support synchronous byte reads, writes, and deletes."""
722-
723-
def set_sync(self, value: Buffer) -> None: ...
724-
725-
def delete_sync(self) -> None: ...
726-
727-
728725
async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None:
729726
"""Set or delete a value in a byte setter
730727

0 commit comments

Comments
 (0)