Skip to content

Commit c65cf82

Browse files
committed
Fix and test for case where some chunks in shard are all fill
1 parent 009ce6a commit c65cf82

File tree

3 files changed

+102
-103
lines changed

3 files changed

+102
-103
lines changed

src/zarr/codecs/sharding.py

Lines changed: 18 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ async def get(
9090
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
9191
) -> Buffer | None:
9292
assert byte_range is None, "byte_range is not supported within shards"
93-
assert (
94-
prototype == default_buffer_prototype()
95-
), f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
93+
assert prototype == default_buffer_prototype(), (
94+
f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
95+
)
9696
return self.shard_dict.get(self.chunk_coords)
9797

9898

@@ -124,9 +124,7 @@ def chunks_per_shard(self) -> ChunkCoords:
124124
def _localize_chunk(self, chunk_coords: ChunkCoords) -> ChunkCoords:
125125
return tuple(
126126
chunk_i % shard_i
127-
for chunk_i, shard_i in zip(
128-
chunk_coords, self.offsets_and_lengths.shape, strict=False
129-
)
127+
for chunk_i, shard_i in zip(chunk_coords, self.offsets_and_lengths.shape, strict=False)
130128
)
131129

132130
def is_all_empty(self) -> bool:
@@ -143,9 +141,7 @@ def get_chunk_slice(self, chunk_coords: ChunkCoords) -> tuple[int, int] | None:
143141
else:
144142
return (int(chunk_start), int(chunk_start + chunk_len))
145143

146-
def set_chunk_slice(
147-
self, chunk_coords: ChunkCoords, chunk_slice: slice | None
148-
) -> None:
144+
def set_chunk_slice(self, chunk_coords: ChunkCoords, chunk_slice: slice | None) -> None:
149145
localized_chunk = self._localize_chunk(chunk_coords)
150146
if chunk_slice is None:
151147
self.offsets_and_lengths[localized_chunk] = (MAX_UINT_64, MAX_UINT_64)
@@ -167,11 +163,7 @@ def is_dense(self, chunk_byte_length: int) -> bool:
167163

168164
# Are all non-empty offsets unique?
169165
if len(
170-
{
171-
offset
172-
for offset, _ in sorted_offsets_and_lengths
173-
if offset != MAX_UINT_64
174-
}
166+
{offset for offset, _ in sorted_offsets_and_lengths if offset != MAX_UINT_64}
175167
) != len(sorted_offsets_and_lengths):
176168
return False
177169

@@ -275,9 +267,7 @@ def __setitem__(self, chunk_coords: ChunkCoords, value: Buffer) -> None:
275267
chunk_start = len(self.buf)
276268
chunk_length = len(value)
277269
self.buf += value
278-
self.index.set_chunk_slice(
279-
chunk_coords, slice(chunk_start, chunk_start + chunk_length)
280-
)
270+
self.index.set_chunk_slice(chunk_coords, slice(chunk_start, chunk_start + chunk_length))
281271

282272
def __delitem__(self, chunk_coords: ChunkCoords) -> None:
283273
raise NotImplementedError
@@ -291,9 +281,7 @@ async def finalize(
291281
if index_location == ShardingCodecIndexLocation.start:
292282
empty_chunks_mask = self.index.offsets_and_lengths[..., 0] == MAX_UINT_64
293283
self.index.offsets_and_lengths[~empty_chunks_mask, 0] += len(index_bytes)
294-
index_bytes = await index_encoder(
295-
self.index
296-
) # encode again with corrected offsets
284+
index_bytes = await index_encoder(self.index) # encode again with corrected offsets
297285
out_buf = index_bytes + self.buf
298286
else:
299287
out_buf = self.buf + index_bytes
@@ -371,8 +359,7 @@ def __init__(
371359
chunk_shape: ChunkCoordsLike,
372360
codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(),),
373361
index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()),
374-
index_location: ShardingCodecIndexLocation
375-
| str = ShardingCodecIndexLocation.end,
362+
index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end,
376363
) -> None:
377364
chunk_shape_parsed = parse_shapelike(chunk_shape)
378365
codecs_parsed = parse_codecs(codecs)
@@ -402,9 +389,7 @@ def __setstate__(self, state: dict[str, Any]) -> None:
402389
object.__setattr__(self, "chunk_shape", parse_shapelike(config["chunk_shape"]))
403390
object.__setattr__(self, "codecs", parse_codecs(config["codecs"]))
404391
object.__setattr__(self, "index_codecs", parse_codecs(config["index_codecs"]))
405-
object.__setattr__(
406-
self, "index_location", parse_index_location(config["index_location"])
407-
)
392+
object.__setattr__(self, "index_location", parse_index_location(config["index_location"]))
408393

409394
# Use instance-local lru_cache to avoid memory leaks
410395
# object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
@@ -433,9 +418,7 @@ def to_dict(self) -> dict[str, JSON]:
433418

434419
def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
435420
shard_spec = self._get_chunk_spec(array_spec)
436-
evolved_codecs = tuple(
437-
c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs
438-
)
421+
evolved_codecs = tuple(c.evolve_from_array_spec(array_spec=shard_spec) for c in self.codecs)
439422
if evolved_codecs != self.codecs:
440423
return replace(self, codecs=evolved_codecs)
441424
return self
@@ -610,9 +593,7 @@ async def _encode_single(
610593
shard_array,
611594
)
612595

613-
return await shard_builder.finalize(
614-
self.index_location, self._encode_shard_index
615-
)
596+
return await shard_builder.finalize(self.index_location, self._encode_shard_index)
616597

617598
async def _encode_partial_single(
618599
self,
@@ -672,8 +653,7 @@ def _is_total_shard(
672653
self, all_chunk_coords: set[ChunkCoords], chunks_per_shard: ChunkCoords
673654
) -> bool:
674655
return len(all_chunk_coords) == product(chunks_per_shard) and all(
675-
chunk_coords in all_chunk_coords
676-
for chunk_coords in c_order_iter(chunks_per_shard)
656+
chunk_coords in all_chunk_coords for chunk_coords in c_order_iter(chunks_per_shard)
677657
)
678658

679659
async def _decode_shard_index(
@@ -699,9 +679,7 @@ async def _encode_shard_index(self, index: _ShardIndex) -> Buffer:
699679
.encode(
700680
[
701681
(
702-
get_ndbuffer_class().from_numpy_array(
703-
index.offsets_and_lengths
704-
),
682+
get_ndbuffer_class().from_numpy_array(index.offsets_and_lengths),
705683
self._get_index_chunk_spec(index.chunks_per_shard),
706684
)
707685
],
@@ -810,9 +788,10 @@ async def _load_partial_shard_maybe(
810788
_ChunkCoordsByteSlice(chunk_coords, slice(*chunk_byte_slice))
811789
for chunk_coords in all_chunk_coords
812790
# Drop chunks where index lookup fails
791+
# e.g. when write_empty_chunks = False and the chunk is empty
813792
if (chunk_byte_slice := shard_index.get_chunk_slice(chunk_coords))
814793
]
815-
if len(chunks) < len(all_chunk_coords):
794+
if len(chunks) == 0:
816795
return None
817796

818797
groups = self._coalesce_chunks(chunks)
@@ -854,9 +833,7 @@ def _coalesce_chunks(
854833

855834
for chunk in sorted_chunks[1:]:
856835
gap_to_chunk = chunk.byte_slice.start - current_group[-1].byte_slice.stop
857-
size_if_coalesced = (
858-
chunk.byte_slice.stop - current_group[0].byte_slice.start
859-
)
836+
size_if_coalesced = chunk.byte_slice.stop - current_group[0].byte_slice.start
860837
if gap_to_chunk < max_gap_bytes and size_if_coalesced < coalesce_max_bytes:
861838
current_group.append(chunk)
862839
else:
@@ -899,9 +876,7 @@ async def _get_group_bytes(
899876

900877
return shard_dict
901878

902-
def compute_encoded_size(
903-
self, input_byte_length: int, shard_spec: ArraySpec
904-
) -> int:
879+
def compute_encoded_size(self, input_byte_length: int, shard_spec: ArraySpec) -> int:
905880
chunks_per_shard = self._get_chunks_per_shard(shard_spec)
906881
return input_byte_length + self._shard_index_size(chunks_per_shard)
907882

tests/test_codecs/test_sharding.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,79 @@ def test_sharding_multiple_chunks_partial_shard_read(
344344
assert isinstance(kwargs["byte_range"], (SuffixByteRequest, RangeByteRequest))
345345

346346

347+
@pytest.mark.parametrize("index_location", ["start", "end"])
348+
@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
349+
def test_sharding_read_empty_chunks_within_non_empty_shard_write_empty_false(
350+
store: Store, index_location: ShardingCodecIndexLocation
351+
) -> None:
352+
"""
353+
Case where
354+
- some, but not all, chunks in the last shard are empty
355+
- the last shard is not complete (array length is not a multiple of shard shape),
356+
this takes us down the partial shard read path
357+
- write_empty_chunks=False so the shard index will have less entries than chunks in the shard
358+
"""
359+
# array with mixed empty and non-empty chunks in second shard
360+
data = np.array([
361+
# shard 0. full 8 elements, all chunks have some non-fill data
362+
0, 1, 2, 3, 4, 5, 6, 7,
363+
# shard 1. 6 elements (< shard shape)
364+
2, 0, # chunk 0, written
365+
0, 0, # chunk 1, all fill, not written
366+
4, 5 # chunk 2, written
367+
], dtype="int32") # fmt: off
368+
369+
spath = StorePath(store)
370+
a = zarr.create_array(
371+
spath,
372+
shape=(14,),
373+
chunks=(2,),
374+
shards={"shape": (8,), "index_location": index_location},
375+
dtype="int32",
376+
fill_value=0,
377+
filters=None,
378+
compressors=None,
379+
config={"write_empty_chunks": False},
380+
)
381+
a[:] = data
382+
383+
assert np.array_equal(a[:], data)
384+
385+
386+
@pytest.mark.parametrize("index_location", ["start", "end"])
387+
@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
388+
def test_sharding_read_empty_chunks_within_empty_shard_write_empty_false(
389+
store: Store, index_location: ShardingCodecIndexLocation
390+
) -> None:
391+
"""
392+
Case where
393+
- all chunks in last shard are empty
394+
- the last shard is not complete (array length is not a multiple of shard shape),
395+
this takes us down the partial shard read path
396+
- write_empty_chunks=False so the shard index will have no entries
397+
"""
398+
fill_value = -99
399+
shard_size = 8
400+
data = np.arange(14, dtype="int32")
401+
data[shard_size:] = fill_value # 2nd shard is all fill value
402+
403+
spath = StorePath(store)
404+
a = zarr.create_array(
405+
spath,
406+
shape=(14,),
407+
chunks=(2,),
408+
shards={"shape": (shard_size,), "index_location": index_location},
409+
dtype="int32",
410+
fill_value=fill_value,
411+
filters=None,
412+
compressors=None,
413+
config={"write_empty_chunks": False},
414+
)
415+
a[:] = data
416+
417+
assert np.array_equal(a[:], data)
418+
419+
347420
@pytest.mark.parametrize("index_location", ["start", "end"])
348421
@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
349422
def test_sharding_partial_shard_read__index_load_fails(
@@ -577,7 +650,6 @@ def test_nested_sharding_create_array(
577650
filters=None,
578651
compressors=None,
579652
)
580-
print(a.metadata.to_dict())
581653

582654
a[:, :, :] = data
583655

@@ -637,7 +709,6 @@ async def test_delete_empty_shards(store: Store) -> None:
637709
compressors=None,
638710
fill_value=1,
639711
)
640-
print(a.metadata.to_dict())
641712
await _AsyncArrayProxy(a)[:, :].set(np.zeros((16, 16)))
642713
await _AsyncArrayProxy(a)[8:, :].set(np.ones((8, 16)))
643714
await _AsyncArrayProxy(a)[:, 8:].set(np.ones((16, 8)))
@@ -682,7 +753,6 @@ async def test_sharding_with_empty_inner_chunk(
682753
)
683754
data[:4, :4] = fill_value
684755
await a.setitem(..., data)
685-
print("read data")
686756
data_read = await a.getitem(...)
687757
assert np.array_equal(data_read, data)
688758

tests/test_properties.py

Lines changed: 11 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,11 @@ def deep_equal(a: Any, b: Any) -> bool:
7676

7777

7878
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
79-
@given(data=st.data(), zarr_format=zarr_formats)
80-
def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None:
81-
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
82-
zarray = data.draw(
83-
arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format))
84-
)
85-
try:
86-
assert_array_equal(nparray, zarray[:])
87-
except Exception as e:
88-
breakpoint()
89-
raise e
79+
@given(data=st.data())
80+
def test_array_roundtrip(data: st.DataObject) -> None:
81+
nparray = data.draw(numpy_arrays())
82+
zarray = data.draw(arrays(arrays=st.just(nparray)))
83+
assert_array_equal(nparray, zarray[:])
9084

9185

9286
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@@ -98,20 +92,12 @@ def test_array_creates_implicit_groups(array):
9892
parent = "/".join(ancestry[: i + 1])
9993
if array.metadata.zarr_format == 2:
10094
assert (
101-
sync(
102-
array.store.get(
103-
f"{parent}/.zgroup", prototype=default_buffer_prototype()
104-
)
105-
)
95+
sync(array.store.get(f"{parent}/.zgroup", prototype=default_buffer_prototype()))
10696
is not None
10797
)
10898
elif array.metadata.zarr_format == 3:
10999
assert (
110-
sync(
111-
array.store.get(
112-
f"{parent}/zarr.json", prototype=default_buffer_prototype()
113-
)
114-
)
100+
sync(array.store.get(f"{parent}/zarr.json", prototype=default_buffer_prototype()))
115101
is not None
116102
)
117103

@@ -129,9 +115,7 @@ def test_basic_indexing(data: st.DataObject) -> None:
129115
actual = zarray[indexer]
130116
assert_array_equal(nparray[indexer], actual)
131117

132-
new_data = data.draw(
133-
numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype)
134-
)
118+
new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
135119
zarray[indexer] = new_data
136120
nparray[indexer] = new_data
137121
assert_array_equal(nparray, zarray[:])
@@ -153,9 +137,7 @@ def test_oindex(data: st.DataObject) -> None:
153137
if isinstance(idxr, np.ndarray) and idxr.size != np.unique(idxr).size:
154138
# behaviour of setitem with repeated indices is not guaranteed in practice
155139
assume(False)
156-
new_data = data.draw(
157-
numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype)
158-
)
140+
new_data = data.draw(numpy_arrays(shapes=st.just(actual.shape), dtype=nparray.dtype))
159141
nparray[npindexer] = new_data
160142
zarray.oindex[zindexer] = new_data
161143
assert_array_equal(nparray, zarray[:])
@@ -231,33 +213,7 @@ def test_roundtrip_array_metadata_from_json(data: st.DataObject, zarr_format: in
231213
orig = metadata.to_dict()
232214
rt = metadata_roundtripped.to_dict()
233215

234-
assert deep_equal(
235-
orig, rt
236-
), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}"
237-
238-
239-
# @st.composite
240-
# def advanced_indices(draw, *, shape):
241-
# basic_idxr = draw(
242-
# basic_indices(
243-
# shape=shape, min_dims=len(shape), max_dims=len(shape), allow_ellipsis=False
244-
# ).filter(lambda x: isinstance(x, tuple))
245-
# )
246-
247-
# int_idxr = draw(
248-
# npst.integer_array_indices(shape=shape, result_shape=npst.array_shapes(max_dims=1))
249-
# )
250-
# args = tuple(
251-
# st.sampled_from((l, r)) for l, r in zip_longest(basic_idxr, int_idxr, fillvalue=slice(None))
252-
# )
253-
# return draw(st.tuples(*args))
254-
255-
256-
# @given(st.data())
257-
# def test_roundtrip_object_array(data):
258-
# nparray = data.draw(np_arrays)
259-
# zarray = data.draw(arrays(arrays=st.just(nparray)))
260-
# assert_array_equal(nparray, zarray[:])
216+
assert deep_equal(orig, rt), f"Roundtrip mismatch:\nOriginal: {orig}\nRoundtripped: {rt}"
261217

262218

263219
def serialized_complex_float_is_valid(
@@ -333,9 +289,7 @@ def test_array_metadata_meets_spec(meta: ArrayV2Metadata | ArrayV3Metadata) -> N
333289
# version-specific validations
334290
if isinstance(meta, ArrayV2Metadata):
335291
assert asdict_dict["filters"] != ()
336-
assert asdict_dict["filters"] is None or isinstance(
337-
asdict_dict["filters"], tuple
338-
)
292+
assert asdict_dict["filters"] is None or isinstance(asdict_dict["filters"], tuple)
339293
assert asdict_dict["zarr_format"] == 2
340294
else:
341295
assert asdict_dict["zarr_format"] == 3

0 commit comments

Comments
 (0)