Skip to content

Commit 2a3b404

Browse files
committed
refactor is_complete_chunk usage, add chunkrequest
1 parent 4632322 commit 2a3b404

File tree

6 files changed

+133
-118
lines changed

6 files changed

+133
-118
lines changed

src/zarr/abc/codec.py

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
from zarr.abc.store import ByteGetter, ByteSetter, Store
2020
from zarr.core.array_spec import ArraySpec
2121
from zarr.core.chunk_grids import ChunkGrid
22+
from zarr.core.codec_pipeline import ChunkRequest
2223
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
2324
from zarr.core.indexing import ChunkProjection, SelectorTuple
2425
from zarr.core.metadata import ArrayMetadata
@@ -751,7 +752,7 @@ async def encode(
751752
@abstractmethod
752753
async def read(
753754
self,
754-
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
755+
batch_info: Iterable[ChunkRequest],
755756
out: NDBuffer,
756757
drop_axes: tuple[int, ...] = (),
757758
) -> None:
@@ -760,12 +761,10 @@ async def read(
760761
761762
Parameters
762763
----------
763-
batch_info : Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple]]
764-
Ordered set of information about the chunks.
765-
The first slice selection determines which parts of the chunk will be fetched.
766-
The second slice selection determines where in the output array the chunk data will be written.
767-
The ByteGetter is used to fetch the necessary bytes.
768-
The chunk spec contains information about the construction of an array from the bytes.
764+
batch_info : Iterable[ChunkRequest]
765+
Ordered set of chunk requests. Each ``ChunkRequest`` carries the
766+
store path (``byte_setter``), the ``ArraySpec`` for that chunk,
767+
chunk and output selections, and whether the chunk is complete.
769768
770769
If the Store returns ``None`` for a chunk, then the chunk was not
771770
written and the implementation must set the values of that chunk (or
@@ -778,7 +777,7 @@ async def read(
778777
@abstractmethod
779778
async def write(
780779
self,
781-
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
780+
batch_info: Iterable[ChunkRequest],
782781
value: NDBuffer,
783782
drop_axes: tuple[int, ...] = (),
784783
) -> None:
@@ -788,12 +787,10 @@ async def write(
788787
789788
Parameters
790789
----------
791-
batch_info : Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]]
792-
Ordered set of information about the chunks.
793-
The first slice selection determines which parts of the chunk will be encoded.
794-
The second slice selection determines where in the value array the chunk data is located.
795-
The ByteSetter is used to fetch and write the necessary bytes.
796-
The chunk spec contains information about the chunk.
790+
batch_info : Iterable[ChunkRequest]
791+
Ordered set of chunk requests. Each ``ChunkRequest`` carries the
792+
store path (``byte_setter``), the ``ArraySpec`` for that chunk,
793+
chunk and output selections, and whether the chunk is complete.
797794
value : NDBuffer
798795
"""
799796
...

src/zarr/codecs/sharding.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -393,7 +393,7 @@ async def _decode_single(
393393

394394
transform = self._get_chunk_transform(chunk_spec)
395395
fill_value = fill_value_or_default(chunk_spec)
396-
for chunk_coords, chunk_selection, out_selection, _ in indexer:
396+
for chunk_coords, chunk_selection, out_selection, _is_complete in indexer:
397397
chunk_bytes = shard_dict.get(chunk_coords)
398398
if chunk_bytes is not None:
399399
chunk_array = await transform.decode_chunk_async(chunk_bytes)
@@ -461,7 +461,7 @@ async def _decode_partial_single(
461461
# decode chunks and write them into the output buffer
462462
transform = self._get_chunk_transform(chunk_spec)
463463
fill_value = fill_value_or_default(chunk_spec)
464-
for chunk_coords, chunk_selection, out_selection, _ in indexed_chunks:
464+
for chunk_coords, chunk_selection, out_selection, _is_complete in indexed_chunks:
465465
chunk_bytes = shard_dict.get(chunk_coords)
466466
if chunk_bytes is not None:
467467
chunk_array = await transform.decode_chunk_async(chunk_bytes)
@@ -541,14 +541,14 @@ async def _encode_partial_single(
541541
fill_value = fill_value_or_default(chunk_spec)
542542

543543
is_scalar = len(shard_array.shape) == 0
544-
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer:
544+
for chunk_coords, chunk_selection, out_selection, is_complete in indexer:
545545
value = shard_array if is_scalar else shard_array[out_selection]
546-
if is_complete_chunk and not is_scalar and value.shape == chunk_spec.shape:
546+
if is_complete and not is_scalar and value.shape == chunk_spec.shape:
547547
# Complete overwrite with matching shape — use value directly
548548
chunk_data = value
549549
else:
550550
# Read-modify-write: decode existing or create new, merge data
551-
if is_complete_chunk:
551+
if is_complete:
552552
existing_bytes = None
553553
else:
554554
existing_bytes = shard_dict.get(chunk_coords)

src/zarr/core/array.py

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@
4848
V2ChunkKeyEncoding,
4949
parse_chunk_key_encoding,
5050
)
51+
from zarr.core.codec_pipeline import ChunkRequest
5152
from zarr.core.common import (
5253
JSON,
5354
ZARR_JSON,
@@ -5602,12 +5603,12 @@ async def _get_selection(
56025603
# reading chunks and decoding them
56035604
await codec_pipeline.read(
56045605
[
5605-
(
5606-
store_path / metadata.encode_chunk_key(chunk_coords),
5607-
metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype),
5608-
chunk_selection,
5609-
out_selection,
5610-
is_complete_chunk,
5606+
ChunkRequest(
5607+
byte_setter=store_path / metadata.encode_chunk_key(chunk_coords),
5608+
chunk_spec=metadata.get_chunk_spec(chunk_coords, _config, prototype=prototype),
5609+
chunk_selection=chunk_selection,
5610+
out_selection=out_selection,
5611+
is_complete_chunk=is_complete_chunk,
56115612
)
56125613
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
56135614
],
@@ -5912,12 +5913,12 @@ async def _set_selection(
59125913
# merging with existing data and encoding chunks
59135914
await codec_pipeline.write(
59145915
[
5915-
(
5916-
store_path / metadata.encode_chunk_key(chunk_coords),
5917-
metadata.get_chunk_spec(chunk_coords, _config, prototype),
5918-
chunk_selection,
5919-
out_selection,
5920-
is_complete_chunk,
5916+
ChunkRequest(
5917+
byte_setter=store_path / metadata.encode_chunk_key(chunk_coords),
5918+
chunk_spec=metadata.get_chunk_spec(chunk_coords, _config, prototype),
5919+
chunk_selection=chunk_selection,
5920+
out_selection=out_selection,
5921+
is_complete_chunk=is_complete_chunk,
59215922
)
59225923
for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer
59235924
],

src/zarr/core/codec_pipeline.py

Lines changed: 57 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -220,6 +220,20 @@ def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
220220
return byte_length
221221

222222

223+
@dataclass(slots=True)
224+
class ChunkRequest:
225+
"""A single chunk's worth of metadata for a pipeline read or write.
226+
227+
Replaces the anonymous 5-tuples formerly threaded through ``batch_info``.
228+
"""
229+
230+
byte_setter: ByteSetter
231+
chunk_spec: ArraySpec
232+
chunk_selection: SelectorTuple
233+
out_selection: SelectorTuple
234+
is_complete_chunk: bool
235+
236+
223237
@dataclass(frozen=True)
224238
class BatchedCodecPipeline(CodecPipeline):
225239
"""Default codec pipeline.
@@ -400,48 +414,40 @@ async def encode_partial_batch(
400414

401415
async def read_batch(
402416
self,
403-
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
417+
batch_info: Iterable[ChunkRequest],
404418
out: NDBuffer,
405419
drop_axes: tuple[int, ...] = (),
406420
) -> None:
421+
batch_info = list(batch_info)
407422
if self.supports_partial_decode:
408423
chunk_array_batch = await self.decode_partial_batch(
409-
[
410-
(byte_getter, chunk_selection, chunk_spec)
411-
for byte_getter, chunk_spec, chunk_selection, *_ in batch_info
412-
]
424+
[(req.byte_setter, req.chunk_selection, req.chunk_spec) for req in batch_info]
413425
)
414-
for chunk_array, (_, chunk_spec, _, out_selection, _) in zip(
415-
chunk_array_batch, batch_info, strict=False
416-
):
426+
for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False):
417427
if chunk_array is not None:
418-
out[out_selection] = chunk_array
428+
out[req.out_selection] = chunk_array
419429
else:
420-
out[out_selection] = fill_value_or_default(chunk_spec)
430+
out[req.out_selection] = fill_value_or_default(req.chunk_spec)
421431
else:
422432
chunk_bytes_batch = await concurrent_map(
423-
[(byte_getter, chunk_spec.prototype) for byte_getter, chunk_spec, *_ in batch_info],
433+
[(req.byte_setter, req.chunk_spec.prototype) for req in batch_info],
424434
lambda byte_getter, prototype: byte_getter.get(prototype),
425435
config.get("async.concurrency"),
426436
)
427437
chunk_array_batch = await self.decode_batch(
428438
[
429-
(chunk_bytes, chunk_spec)
430-
for chunk_bytes, (_, chunk_spec, *_) in zip(
431-
chunk_bytes_batch, batch_info, strict=False
432-
)
439+
(chunk_bytes, req.chunk_spec)
440+
for chunk_bytes, req in zip(chunk_bytes_batch, batch_info, strict=False)
433441
],
434442
)
435-
for chunk_array, (_, chunk_spec, chunk_selection, out_selection, _) in zip(
436-
chunk_array_batch, batch_info, strict=False
437-
):
443+
for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False):
438444
if chunk_array is not None:
439-
tmp = chunk_array[chunk_selection]
445+
tmp = chunk_array[req.chunk_selection]
440446
if drop_axes != ():
441447
tmp = tmp.squeeze(axis=drop_axes)
442-
out[out_selection] = tmp
448+
out[req.out_selection] = tmp
443449
else:
444-
out[out_selection] = fill_value_or_default(chunk_spec)
450+
out[req.out_selection] = fill_value_or_default(req.chunk_spec)
445451

446452
def _merge_chunk_array(
447453
self,
@@ -450,13 +456,11 @@ def _merge_chunk_array(
450456
out_selection: SelectorTuple,
451457
chunk_spec: ArraySpec,
452458
chunk_selection: SelectorTuple,
453-
is_complete_chunk: bool,
454459
drop_axes: tuple[int, ...],
455460
) -> NDBuffer:
456461
if (
457-
is_complete_chunk
462+
existing_chunk_array is None
458463
and value.shape == chunk_spec.shape
459-
# Guard that this is not a partial chunk at the end with is_complete_chunk=True
460464
and value[out_selection].shape == chunk_spec.shape
461465
):
462466
return value
@@ -489,24 +493,30 @@ def _merge_chunk_array(
489493

490494
async def write_batch(
491495
self,
492-
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
496+
batch_info: Iterable[ChunkRequest],
493497
value: NDBuffer,
494498
drop_axes: tuple[int, ...] = (),
495499
) -> None:
500+
batch_info = list(batch_info)
496501
if self.supports_partial_encode:
497502
# Pass scalar values as is
498503
if len(value.shape) == 0:
499504
await self.encode_partial_batch(
500505
[
501-
(byte_setter, value, chunk_selection, chunk_spec)
502-
for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
506+
(req.byte_setter, value, req.chunk_selection, req.chunk_spec)
507+
for req in batch_info
503508
],
504509
)
505510
else:
506511
await self.encode_partial_batch(
507512
[
508-
(byte_setter, value[out_selection], chunk_selection, chunk_spec)
509-
for byte_setter, chunk_spec, chunk_selection, out_selection, _ in batch_info
513+
(
514+
req.byte_setter,
515+
value[req.out_selection],
516+
req.chunk_selection,
517+
req.chunk_spec,
518+
)
519+
for req in batch_info
510520
],
511521
)
512522

@@ -523,61 +533,48 @@ async def _read_key(
523533
chunk_bytes_batch = await concurrent_map(
524534
[
525535
(
526-
None if is_complete_chunk else byte_setter,
527-
chunk_spec.prototype,
536+
None if req.is_complete_chunk else req.byte_setter,
537+
req.chunk_spec.prototype,
528538
)
529-
for byte_setter, chunk_spec, chunk_selection, _, is_complete_chunk in batch_info
539+
for req in batch_info
530540
],
531541
_read_key,
532542
config.get("async.concurrency"),
533543
)
534544
chunk_array_decoded = await self.decode_batch(
535545
[
536-
(chunk_bytes, chunk_spec)
537-
for chunk_bytes, (_, chunk_spec, *_) in zip(
538-
chunk_bytes_batch, batch_info, strict=False
539-
)
546+
(chunk_bytes, req.chunk_spec)
547+
for chunk_bytes, req in zip(chunk_bytes_batch, batch_info, strict=False)
540548
],
541549
)
542550

543551
chunk_array_merged = [
544552
self._merge_chunk_array(
545553
chunk_array,
546554
value,
547-
out_selection,
548-
chunk_spec,
549-
chunk_selection,
550-
is_complete_chunk,
555+
req.out_selection,
556+
req.chunk_spec,
557+
req.chunk_selection,
551558
drop_axes,
552559
)
553-
for chunk_array, (
554-
_,
555-
chunk_spec,
556-
chunk_selection,
557-
out_selection,
558-
is_complete_chunk,
559-
) in zip(chunk_array_decoded, batch_info, strict=False)
560+
for chunk_array, req in zip(chunk_array_decoded, batch_info, strict=False)
560561
]
561562
chunk_array_batch: list[NDBuffer | None] = []
562-
for chunk_array, (_, chunk_spec, *_) in zip(
563-
chunk_array_merged, batch_info, strict=False
564-
):
563+
for chunk_array, req in zip(chunk_array_merged, batch_info, strict=False):
565564
if chunk_array is None:
566565
chunk_array_batch.append(None) # type: ignore[unreachable]
567566
else:
568-
if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(
569-
fill_value_or_default(chunk_spec)
567+
if not req.chunk_spec.config.write_empty_chunks and chunk_array.all_equal(
568+
fill_value_or_default(req.chunk_spec)
570569
):
571570
chunk_array_batch.append(None)
572571
else:
573572
chunk_array_batch.append(chunk_array)
574573

575574
chunk_bytes_batch = await self.encode_batch(
576575
[
577-
(chunk_array, chunk_spec)
578-
for chunk_array, (_, chunk_spec, *_) in zip(
579-
chunk_array_batch, batch_info, strict=False
580-
)
576+
(chunk_array, req.chunk_spec)
577+
for chunk_array, req in zip(chunk_array_batch, batch_info, strict=False)
581578
],
582579
)
583580

@@ -589,10 +586,8 @@ async def _write_key(byte_setter: ByteSetter, chunk_bytes: Buffer | None) -> Non
589586

590587
await concurrent_map(
591588
[
592-
(byte_setter, chunk_bytes)
593-
for chunk_bytes, (byte_setter, *_) in zip(
594-
chunk_bytes_batch, batch_info, strict=False
595-
)
589+
(req.byte_setter, chunk_bytes)
590+
for chunk_bytes, req in zip(chunk_bytes_batch, batch_info, strict=False)
596591
],
597592
_write_key,
598593
config.get("async.concurrency"),
@@ -618,7 +613,7 @@ async def encode(
618613

619614
async def read(
620615
self,
621-
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
616+
batch_info: Iterable[ChunkRequest],
622617
out: NDBuffer,
623618
drop_axes: tuple[int, ...] = (),
624619
) -> None:
@@ -633,7 +628,7 @@ async def read(
633628

634629
async def write(
635630
self,
636-
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
631+
batch_info: Iterable[ChunkRequest],
637632
value: NDBuffer,
638633
drop_axes: tuple[int, ...] = (),
639634
) -> None:

0 commit comments

Comments (0)