Skip to content

Commit 715ec2a

Browse files
authored
Merge branch 'main' into ig/spec0_py314
2 parents 884ad70 + b76e006 commit 715ec2a

2 files changed

Lines changed: 54 additions & 13 deletions

File tree

src/zarr/codecs/sharding.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterable, Mapping, MutableMapping
3+
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
44
from dataclasses import dataclass, replace
55
from enum import Enum
66
from functools import lru_cache
@@ -45,6 +45,7 @@
4545
from zarr.core.dtype.npy.int import UInt64
4646
from zarr.core.indexing import (
4747
BasicIndexer,
48+
ChunkProjection,
4849
SelectorTuple,
4950
_morton_order,
5051
_morton_order_keys,
@@ -574,21 +575,26 @@ async def _encode_partial_single(
574575
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
575576
chunk_spec = self._get_chunk_spec(shard_spec)
576577

577-
shard_reader = await self._load_full_shard_maybe(
578-
byte_getter=byte_setter,
579-
prototype=chunk_spec.prototype,
580-
chunks_per_shard=chunks_per_shard,
581-
)
582-
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
583-
# Use vectorized lookup for better performance
584-
shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard)))
585-
586578
indexer = list(
587579
get_indexer(
588580
selection, shape=shard_shape, chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape)
589581
)
590582
)
591583

584+
if self._is_complete_shard_write(indexer, chunks_per_shard):
585+
shard_dict = dict.fromkeys(morton_order_iter(chunks_per_shard))
586+
else:
587+
shard_reader = await self._load_full_shard_maybe(
588+
byte_getter=byte_setter,
589+
prototype=chunk_spec.prototype,
590+
chunks_per_shard=chunks_per_shard,
591+
)
592+
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
593+
# Use vectorized lookup for better performance
594+
shard_dict = shard_reader.to_dict_vectorized(
595+
np.asarray(_morton_order(chunks_per_shard))
596+
)
597+
592598
await self.codec_pipeline.write(
593599
[
594600
(
@@ -661,6 +667,16 @@ def _is_total_shard(
661667
chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard)
662668
)
663669

670+
def _is_complete_shard_write(
671+
self,
672+
indexed_chunks: Sequence[ChunkProjection],
673+
chunks_per_shard: tuple[int, ...],
674+
) -> bool:
675+
all_chunk_coords = {chunk_coords for chunk_coords, *_ in indexed_chunks}
676+
return self._is_total_shard(all_chunk_coords, chunks_per_shard) and all(
677+
is_complete_chunk for *_, is_complete_chunk in indexed_chunks
678+
)
679+
664680
async def _decode_shard_index(
665681
self, index_bytes: Buffer, chunks_per_shard: tuple[int, ...]
666682
) -> _ShardIndex:

tests/test_array.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,9 +2259,34 @@ def test_create_array_with_data_num_gets(
22592259
data = zarr.zeros(shape, dtype="int64")
22602260

22612261
zarr.create_array(store, data=data, chunks=chunk_shape, shards=shard_shape, fill_value=-1) # type: ignore[arg-type]
2262-
# one get for the metadata and one per shard.
2263-
# Note: we don't actually need one get per shard, but this is the current behavior
2264-
assert store.counter["get"] == 1 + num_shards
2262+
# One get for the metadata; full-shard writes should not read shard payloads.
2263+
assert store.counter["get"] == 1
2264+
2265+
2266+
@pytest.mark.parametrize(
2267+
("selection", "expected_gets"),
2268+
[(slice(None), 0), (slice(1, 9), 1)],
2269+
)
2270+
def test_shard_write_num_gets(selection: slice, expected_gets: int) -> None:
2271+
"""
2272+
Test that partial-shard writes read the existing data and full-shard writes don't.
2273+
"""
2274+
store = LoggingStore(store=MemoryStore())
2275+
arr = zarr.create_array(
2276+
store,
2277+
shape=(10,),
2278+
chunks=(1,),
2279+
shards=(10,),
2280+
dtype="int64",
2281+
fill_value=-1,
2282+
)
2283+
arr[:] = 0
2284+
2285+
store.counter.clear()
2286+
2287+
arr[selection] = 1
2288+
2289+
assert store.counter["get"] == expected_gets
22652290

22662291

22672292
@pytest.mark.parametrize("config", [{}, {"write_empty_chunks": True}, {"order": "C"}])

0 commit comments

Comments
 (0)