Skip to content

Commit 5477d70

Browse files
committed
feat: subchunk write order
1 parent 46843be commit 5477d70

2 files changed

Lines changed: 132 additions & 6 deletions

File tree

src/zarr/codecs/sharding.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import random
34
from collections.abc import Iterable, Mapping, MutableMapping
45
from dataclasses import dataclass, replace
56
from enum import Enum
@@ -46,7 +47,6 @@
4647
from zarr.core.indexing import (
4748
BasicIndexer,
4849
SelectorTuple,
49-
_morton_order,
5050
_morton_order_keys,
5151
c_order_iter,
5252
get_indexer,
@@ -77,10 +77,27 @@ class ShardingCodecIndexLocation(Enum):
7777
end = "end"
7878

7979

80+
class SubchunkWriteOrder(Enum):
81+
"""
82+
Enum for the order in which subchunks are written within a shard.
83+
84+
unordered is implemented via `random.shuffle` over the lexicographic order.
85+
"""
86+
87+
morton = "morton"
88+
unordered = "unordered"
89+
lexicographic = "lexicographic"
90+
colexicographic = "colexicographic"
91+
92+
8093
def parse_index_location(data: object) -> ShardingCodecIndexLocation:
8194
return parse_enum(data, ShardingCodecIndexLocation)
8295

8396

97+
def parse_subchunk_write_order(data: object) -> SubchunkWriteOrder:
98+
return parse_enum(data, SubchunkWriteOrder)
99+
100+
84101
@dataclass(frozen=True)
85102
class _ShardingByteGetter(ByteGetter):
86103
shard_dict: ShardMapping
@@ -305,6 +322,7 @@ class ShardingCodec(
305322
codecs: tuple[Codec, ...]
306323
index_codecs: tuple[Codec, ...]
307324
index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end
325+
subchunk_write_order: SubchunkWriteOrder = SubchunkWriteOrder.morton
308326

309327
def __init__(
310328
self,
@@ -313,16 +331,19 @@ def __init__(
313331
codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(),),
314332
index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()),
315333
index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end,
334+
subchunk_write_order: SubchunkWriteOrder | str = SubchunkWriteOrder.morton,
316335
) -> None:
317336
chunk_shape_parsed = parse_shapelike(chunk_shape)
318337
codecs_parsed = parse_codecs(codecs)
319338
index_codecs_parsed = parse_codecs(index_codecs)
320339
index_location_parsed = parse_index_location(index_location)
340+
subchunk_write_order_parsed = parse_subchunk_write_order(subchunk_write_order)
321341

322342
object.__setattr__(self, "chunk_shape", chunk_shape_parsed)
323343
object.__setattr__(self, "codecs", codecs_parsed)
324344
object.__setattr__(self, "index_codecs", index_codecs_parsed)
325345
object.__setattr__(self, "index_location", index_location_parsed)
346+
object.__setattr__(self, "subchunk_write_order", subchunk_write_order_parsed)
326347

327348
# Use instance-local lru_cache to avoid memory leaks
328349

@@ -522,6 +543,20 @@ async def _decode_partial_single(
522543
else:
523544
return out
524545

546+
def _subchunk_iter(self, chunks_per_shard: tuple[int, ...]) -> Iterable[tuple[int, ...]]:
547+
match self.subchunk_write_order:
548+
case SubchunkWriteOrder.morton:
549+
subchunk_iter = morton_order_iter(chunks_per_shard)
550+
case SubchunkWriteOrder.lexicographic:
551+
subchunk_iter = np.ndindex(chunks_per_shard)
552+
case SubchunkWriteOrder.colexicographic:
553+
subchunk_iter = (c[::-1] for c in np.ndindex(chunks_per_shard[::-1]))
554+
case SubchunkWriteOrder.unordered:
555+
subchunk_list = list(np.ndindex(chunks_per_shard))
556+
random.shuffle(subchunk_list)
557+
subchunk_iter = iter(subchunk_list)
558+
return subchunk_iter
559+
525560
async def _encode_single(
526561
self,
527562
shard_array: NDBuffer,
@@ -539,8 +574,7 @@ async def _encode_single(
539574
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
540575
)
541576
)
542-
543-
shard_builder = dict.fromkeys(morton_order_iter(chunks_per_shard))
577+
shard_builder = dict.fromkeys(self._subchunk_iter(chunks_per_shard))
544578

545579
await self.codec_pipeline.write(
546580
[
@@ -581,7 +615,9 @@ async def _encode_partial_single(
581615
)
582616
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
583617
# Use vectorized lookup for better performance
584-
shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard)))
618+
shard_dict = shard_reader.to_dict_vectorized(
619+
np.asarray(list(self._subchunk_iter(chunks_per_shard)))
620+
)
585621

586622
indexer = list(
587623
get_indexer(
@@ -625,7 +661,7 @@ async def _encode_shard_dict(
625661

626662
template = buffer_prototype.buffer.create_zero_length()
627663
chunk_start = 0
628-
for chunk_coords in morton_order_iter(chunks_per_shard):
664+
for chunk_coords in self._subchunk_iter(chunks_per_shard):
629665
value = map.get(chunk_coords)
630666
if value is None:
631667
continue

tests/test_codecs/test_sharding.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
from zarr.abc.store import Store
1414
from zarr.codecs import (
1515
BloscCodec,
16+
BytesCodec,
17+
Crc32cCodec,
1618
ShardingCodec,
1719
ShardingCodecIndexLocation,
1820
TransposeCodec,
1921
)
22+
from zarr.codecs.sharding import SubchunkWriteOrder, _ShardReader
2023
from zarr.core.buffer import NDArrayLike, default_buffer_prototype
21-
from zarr.storage import StorePath, ZipStore
24+
from zarr.storage import MemoryStore, StorePath, ZipStore
2225

2326
from ..conftest import ArrayRequest
2427
from .test_codecs import _AsyncArrayProxy, order_from_dim
@@ -555,3 +558,90 @@ def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
555558
s3 = sharded[0:5, 1, 0:3]
556559
assert c3.shape == s3.shape == (5, 3) # type: ignore[union-attr]
557560
np.testing.assert_array_equal(c3, s3)
561+
562+
563+
@pytest.mark.parametrize(
564+
"subchunk_write_order",
565+
list(SubchunkWriteOrder),
566+
)
567+
async def test_encoded_subchunk_write_order(
568+
subchunk_write_order: SubchunkWriteOrder,
569+
) -> None:
570+
"""Subchunks must be physically laid out in the shard in the order specified by
571+
``subchunk_write_order``. We verify this by decoding the shard index and sorting
572+
the chunk coordinates by their byte offset."""
573+
# Use a non-square chunks_per_shard so the three deterministic orderings
574+
chunks_per_shard = (3, 2)
575+
chunk_shape = (4, 4)
576+
shard_shape = tuple(c * s for c, s in zip(chunks_per_shard, chunk_shape, strict=True))
577+
578+
codec = ShardingCodec(
579+
chunk_shape=chunk_shape,
580+
codecs=[BytesCodec()],
581+
index_codecs=[BytesCodec(), Crc32cCodec()],
582+
index_location=ShardingCodecIndexLocation.end,
583+
subchunk_write_order=subchunk_write_order,
584+
)
585+
store = MemoryStore()
586+
arr = zarr.create_array(
587+
StorePath(store),
588+
shape=shard_shape,
589+
dtype="uint8",
590+
chunks=shard_shape,
591+
serializer=codec,
592+
filters=None,
593+
compressors=None,
594+
fill_value=0,
595+
)
596+
597+
arr[:] = np.arange(np.prod(shard_shape), dtype="uint8").reshape(shard_shape)
598+
599+
shard_buf = await store.get("c/0/0", prototype=default_buffer_prototype())
600+
if shard_buf is None:
601+
raise RuntimeError("data write failed")
602+
index = (await _ShardReader.from_bytes(shard_buf, codec, chunks_per_shard)).index
603+
offset_to_coord: dict[int, tuple[int, ...]] = dict(
604+
zip(
605+
index.get_chunk_slices_vectorized(np.array(list(np.ndindex(chunks_per_shard))))[
606+
0
607+
], # start
608+
list(np.ndindex(chunks_per_shard)), # coord
609+
strict=True,
610+
)
611+
)
612+
613+
# The physical write order is recovered by sorting coordinates by start offset.
614+
actual_order = [coord for _, coord in sorted(offset_to_coord.items())]
615+
expected_order = list(codec._subchunk_iter(chunks_per_shard))
616+
assert (actual_order == expected_order) == (
617+
subchunk_write_order != SubchunkWriteOrder.unordered
618+
)
619+
620+
621+
@pytest.mark.parametrize(
622+
"subchunk_write_order",
623+
list(SubchunkWriteOrder),
624+
)
625+
def test_subchunk_write_order_roundtrip(subchunk_write_order: SubchunkWriteOrder) -> None:
626+
"""Data written with any ``subchunk_write_order`` must round-trip correctly."""
627+
chunks_per_shard = (3, 2)
628+
chunk_shape = (4, 4)
629+
shard_shape = tuple(c * s for c, s in zip(chunks_per_shard, chunk_shape, strict=True))
630+
data = np.arange(np.prod(shard_shape), dtype="uint16").reshape(shard_shape)
631+
632+
arr = zarr.create_array(
633+
StorePath(MemoryStore()),
634+
shape=shard_shape,
635+
dtype=data.dtype,
636+
chunks=shard_shape,
637+
serializer=ShardingCodec(
638+
chunk_shape=chunk_shape,
639+
codecs=[BytesCodec()],
640+
subchunk_write_order=subchunk_write_order,
641+
),
642+
filters=None,
643+
compressors=None,
644+
fill_value=0,
645+
)
646+
arr[:] = data
647+
np.testing.assert_array_equal(arr[:], data)

0 commit comments

Comments
 (0)