Skip to content

Commit b2d9803

Browse files
committed
ensure no get on full shard writes
1 parent 9a71d59 commit b2d9803

2 files changed

Lines changed: 72 additions & 13 deletions

File tree

src/zarr/codecs/sharding.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterable, Mapping, MutableMapping
3+
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
44
from dataclasses import dataclass, replace
55
from enum import Enum
66
from functools import lru_cache
@@ -45,6 +45,7 @@
4545
from zarr.core.dtype.npy.int import UInt64
4646
from zarr.core.indexing import (
4747
BasicIndexer,
48+
ChunkProjection,
4849
SelectorTuple,
4950
_morton_order,
5051
_morton_order_keys,
@@ -574,21 +575,26 @@ async def _encode_partial_single(
574575
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
575576
chunk_spec = self._get_chunk_spec(shard_spec)
576577

577-
shard_reader = await self._load_full_shard_maybe(
578-
byte_getter=byte_setter,
579-
prototype=chunk_spec.prototype,
580-
chunks_per_shard=chunks_per_shard,
581-
)
582-
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
583-
# Use vectorized lookup for better performance
584-
shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard)))
585-
586578
indexer = list(
587579
get_indexer(
588580
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
589581
)
590582
)
591583

584+
if self._is_complete_shard_write(indexer, chunks_per_shard):
585+
shard_dict = dict.fromkeys(morton_order_iter(chunks_per_shard))
586+
else:
587+
shard_reader = await self._load_full_shard_maybe(
588+
byte_getter=byte_setter,
589+
prototype=chunk_spec.prototype,
590+
chunks_per_shard=chunks_per_shard,
591+
)
592+
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
593+
# Use vectorized lookup for better performance
594+
shard_dict = shard_reader.to_dict_vectorized(
595+
np.asarray(_morton_order(chunks_per_shard))
596+
)
597+
592598
await self.codec_pipeline.write(
593599
[
594600
(
@@ -661,6 +667,16 @@ def _is_total_shard(
661667
chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard)
662668
)
663669

670+
def _is_complete_shard_write(
671+
self,
672+
indexed_chunks: Sequence[ChunkProjection],
673+
chunks_per_shard: tuple[int, ...],
674+
) -> bool:
675+
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}
676+
return self._is_total_shard(all_chunk_coords, chunks_per_shard) and all(
677+
is_complete_chunk for *_, is_complete_chunk in indexed_chunks
678+
)
679+
664680
async def _decode_shard_index(
665681
self, index_bytes: Buffer, chunks_per_shard: tuple[int, ...]
666682
) -> _ShardIndex:

tests/test_array.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,9 +2259,52 @@ def test_create_array_with_data_num_gets(
22592259
data = zarr.zeros(shape, dtype="int64")
22602260

22612261
zarr.create_array(store, data=data, chunks=chunk_shape, shards=shard_shape, fill_value=-1) # type: ignore[arg-type]
2262-
# one get for the metadata and one per shard.
2263-
# Note: we don't actually need one get per shard, but this is the current behavior
2264-
assert store.counter["get"] == 1 + num_shards
2262+
# One get for the metadata; full-shard writes should not read shard payloads.
2263+
assert store.counter["get"] == 1
2264+
2265+
2266+
def test_full_shard_write_num_gets() -> None:
2267+
"""
2268+
Test that overwriting a complete shard does not read the existing shard first.
2269+
"""
2270+
store = LoggingStore(store=MemoryStore())
2271+
arr = zarr.create_array(
2272+
store,
2273+
shape=(10,),
2274+
chunks=(1,),
2275+
shards=(10,),
2276+
dtype="int64",
2277+
fill_value=-1,
2278+
)
2279+
arr[:] = 0
2280+
2281+
store.counter.clear()
2282+
2283+
arr[:] = np.arange(10, dtype="int64")
2284+
2285+
assert store.counter["get"] == 0
2286+
2287+
2288+
def test_partial_shard_write_num_gets() -> None:
2289+
"""
2290+
Test that partial shard writes still read the existing shard to preserve untouched chunks.
2291+
"""
2292+
store = LoggingStore(store=MemoryStore())
2293+
arr = zarr.create_array(
2294+
store,
2295+
shape=(10,),
2296+
chunks=(1,),
2297+
shards=(10,),
2298+
dtype="int64",
2299+
fill_value=-1,
2300+
)
2301+
arr[:] = 0
2302+
2303+
store.counter.clear()
2304+
2305+
arr[1:9] = np.arange(8, dtype="int64")
2306+
2307+
assert store.counter["get"] == 1
22652308

22662309

22672310
@pytest.mark.parametrize("config", [{}, {"write_empty_chunks": True}, {"order": "C"}])

0 commit comments

Comments
 (0)