Skip to content

Commit 417df78

Browse files
committed
Merge branch 'main' into ig/shard_order
2 parents 58e071c + b76e006 commit 417df78

3 files changed

Lines changed: 81 additions & 25 deletions

File tree

src/zarr/codecs/sharding.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import random
4-
from collections.abc import Iterable, Mapping, MutableMapping
4+
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
55
from dataclasses import dataclass, replace
66
from enum import Enum
77
from functools import lru_cache
@@ -46,7 +46,9 @@
4646
from zarr.core.dtype.npy.int import UInt64
4747
from zarr.core.indexing import (
4848
BasicIndexer,
49+
ChunkProjection,
4950
SelectorTuple,
51+
_morton_order,
5052
_morton_order_keys,
5153
c_order_iter,
5254
get_indexer,
@@ -543,7 +545,7 @@ async def _decode_partial_single(
543545
else:
544546
return out
545547

546-
def _subchunk_iter(self, chunks_per_shard: tuple[int, ...]) -> Iterable[tuple[int, ...]]:
548+
def _subchunk_order_iter(self, chunks_per_shard: tuple[int, ...]) -> Iterable[tuple[int, ...]]:
547549
match self.subchunk_write_order:
548550
case SubchunkWriteOrder.morton:
549551
subchunk_iter = morton_order_iter(chunks_per_shard)
@@ -557,6 +559,17 @@ def _subchunk_iter(self, chunks_per_shard: tuple[int, ...]) -> Iterable[tuple[in
557559
subchunk_iter = iter(subchunk_list)
558560
return subchunk_iter
559561

562+
def _subchunk_order_vectorized(self, chunks_per_shard: tuple[int, ...]) -> npt.NDArray[np.intp]:
563+
match self.subchunk_write_order:
564+
case SubchunkWriteOrder.morton:
565+
subchunk_order_vectorized = _morton_order(chunks_per_shard)
566+
case _:
567+
subchunk_order_vectorized = np.fromiter(
568+
self._subchunk_order_iter(chunks_per_shard),
569+
dtype=np.dtype((int, len(chunks_per_shard))),
570+
)
571+
return subchunk_order_vectorized
572+
560573
async def _encode_single(
561574
self,
562575
shard_array: NDBuffer,
@@ -574,7 +587,7 @@ async def _encode_single(
574587
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
575588
)
576589
)
577-
shard_builder = dict.fromkeys(self._subchunk_iter(chunks_per_shard))
590+
shard_builder = dict.fromkeys(self._subchunk_order_iter(chunks_per_shard))
578591

579592
await self.codec_pipeline.write(
580593
[
@@ -608,23 +621,26 @@ async def _encode_partial_single(
608621
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
609622
chunk_spec = self._get_chunk_spec(shard_spec)
610623

611-
shard_reader = await self._load_full_shard_maybe(
612-
byte_getter=byte_setter,
613-
prototype=chunk_spec.prototype,
614-
chunks_per_shard=chunks_per_shard,
615-
)
616-
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
617-
# Use vectorized lookup for better performance
618-
shard_dict = shard_reader.to_dict_vectorized(
619-
np.asarray(list(self._subchunk_iter(chunks_per_shard)))
620-
)
621-
622624
indexer = list(
623625
get_indexer(
624626
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
625627
)
626628
)
627629

630+
if self._is_complete_shard_write(indexer, chunks_per_shard):
631+
shard_dict = dict.fromkeys(self._subchunk_order_iter(chunks_per_shard))
632+
else:
633+
shard_reader = await self._load_full_shard_maybe(
634+
byte_getter=byte_setter,
635+
prototype=chunk_spec.prototype,
636+
chunks_per_shard=chunks_per_shard,
637+
)
638+
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
639+
# Use vectorized lookup for better performance
640+
shard_dict = shard_reader.to_dict_vectorized(
641+
self._subchunk_order_vectorized(chunks_per_shard)
642+
)
643+
628644
await self.codec_pipeline.write(
629645
[
630646
(
@@ -661,7 +677,7 @@ async def _encode_shard_dict(
661677

662678
template = buffer_prototype.buffer.create_zero_length()
663679
chunk_start = 0
664-
for chunk_coords in self._subchunk_iter(chunks_per_shard):
680+
for chunk_coords in self._subchunk_order_iter(chunks_per_shard):
665681
value = map.get(chunk_coords)
666682
if value is None:
667683
continue
@@ -697,6 +713,16 @@ def _is_total_shard(
697713
chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard)
698714
)
699715

716+
def _is_complete_shard_write(
717+
self,
718+
indexed_chunks: Sequence[ChunkProjection],
719+
chunks_per_shard: tuple[int, ...],
720+
) -> bool:
721+
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}
722+
return self._is_total_shard(all_chunk_coords, chunks_per_shard) and all(
723+
is_complete_chunk for *_, is_complete_chunk in indexed_chunks
724+
)
725+
700726
async def _decode_shard_index(
701727
self, index_bytes: Buffer, chunks_per_shard: tuple[int, ...]
702728
) -> _ShardIndex:

tests/test_array.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,9 +2259,34 @@ def test_create_array_with_data_num_gets(
22592259
data = zarr.zeros(shape, dtype="int64")
22602260

22612261
zarr.create_array(store, data=data, chunks=chunk_shape, shards=shard_shape, fill_value=-1) # type: ignore[arg-type]
2262-
# one get for the metadata and one per shard.
2263-
# Note: we don't actually need one get per shard, but this is the current behavior
2264-
assert store.counter["get"] == 1 + num_shards
2262+
# One get for the metadata; full-shard writes should not read shard payloads.
2263+
assert store.counter["get"] == 1
2264+
2265+
2266+
@pytest.mark.parametrize(
2267+
("selection", "expected_gets"),
2268+
[(slice(None), 0), (slice(1, 9), 1)],
2269+
)
2270+
def test_shard_write_num_gets(selection: slice, expected_gets: int) -> None:
2271+
"""
2272+
Test that partial-shard writes read the existing data and full-shard writes don't.
2273+
"""
2274+
store = LoggingStore(store=MemoryStore())
2275+
arr = zarr.create_array(
2276+
store,
2277+
shape=(10,),
2278+
chunks=(1,),
2279+
shards=(10,),
2280+
dtype="int64",
2281+
fill_value=-1,
2282+
)
2283+
arr[:] = 0
2284+
2285+
store.counter.clear()
2286+
2287+
arr[selection] = 1
2288+
2289+
assert store.counter["get"] == expected_gets
22652290

22662291

22672292
@pytest.mark.parametrize("config", [{}, {"write_empty_chunks": True}, {"order": "C"}])

tests/test_codecs/test_sharding.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -564,9 +564,7 @@ def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
564564
"subchunk_write_order",
565565
list(SubchunkWriteOrder),
566566
)
567-
async def test_encoded_subchunk_write_order(
568-
subchunk_write_order: SubchunkWriteOrder,
569-
) -> None:
567+
async def test_encoded_subchunk_write_order(subchunk_write_order: SubchunkWriteOrder) -> None:
570568
"""Subchunks must be physically laid out in the shard in the order specified by
571569
``subchunk_write_order``. We verify this by decoding the shard index and sorting
572570
the chunk coordinates by their byte offset."""
@@ -612,7 +610,7 @@ async def test_encoded_subchunk_write_order(
612610

613611
# The physical write order is recovered by sorting coordinates by start offset.
614612
actual_order = [coord for _, coord in sorted(offset_to_coord.items())]
615-
expected_order = list(codec._subchunk_iter(chunks_per_shard))
613+
expected_order = list(codec._subchunk_order_iter(chunks_per_shard))
616614
assert (actual_order == expected_order) == (
617615
subchunk_write_order != SubchunkWriteOrder.unordered
618616
)
@@ -622,13 +620,15 @@ async def test_encoded_subchunk_write_order(
622620
"subchunk_write_order",
623621
list(SubchunkWriteOrder),
624622
)
625-
def test_subchunk_write_order_roundtrip(subchunk_write_order: SubchunkWriteOrder) -> None:
623+
@pytest.mark.parametrize("do_partial", [True, False], ids=["partial", "complete"])
624+
def test_subchunk_write_order_roundtrip(
625+
subchunk_write_order: SubchunkWriteOrder, do_partial: bool
626+
) -> None:
626627
"""Data written with any ``subchunk_write_order`` must round-trip correctly."""
627628
chunks_per_shard = (3, 2)
628629
chunk_shape = (4, 4)
629630
shard_shape = tuple(c * s for c, s in zip(chunks_per_shard, chunk_shape, strict=True))
630631
data = np.arange(np.prod(shard_shape), dtype="uint16").reshape(shard_shape)
631-
632632
arr = zarr.create_array(
633633
StorePath(MemoryStore()),
634634
shape=shard_shape,
@@ -643,5 +643,10 @@ def test_subchunk_write_order_roundtrip(subchunk_write_order: SubchunkWriteOrder
643643
compressors=None,
644644
fill_value=0,
645645
)
646-
arr[:] = data
646+
if do_partial:
647+
sub_data = data[: (shard_shape[0] // 2)]
648+
arr[: (shard_shape[0] // 2)] = data[: (shard_shape[0] // 2)]
649+
data = np.vstack([sub_data, np.zeros_like(sub_data)])
650+
else:
651+
arr[:] = data
647652
np.testing.assert_array_equal(arr[:], data)

0 commit comments

Comments
 (0)