Skip to content

Commit 7b663ff

Browse files
committed
feat: deterministic but random order
1 parent a89249a commit 7b663ff

2 files changed

Lines changed: 73 additions & 40 deletions

File tree

src/zarr/codecs/sharding.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import random
43
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
54
from dataclasses import dataclass, replace
65
from enum import Enum
@@ -48,7 +47,6 @@
4847
BasicIndexer,
4948
ChunkProjection,
5049
SelectorTuple,
51-
_morton_order,
5250
_morton_order_keys,
5351
c_order_iter,
5452
get_indexer,
@@ -315,6 +313,7 @@ class ShardingCodec(
315313
chunk_shape: tuple[int, ...]
316314
codecs: tuple[Codec, ...]
317315
index_codecs: tuple[Codec, ...]
316+
rng: np.random.Generator | None
318317
index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end
319318
subchunk_write_order: SubchunkWriteOrder = "morton"
320319

@@ -326,6 +325,7 @@ def __init__(
326325
index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()),
327326
index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end,
328327
subchunk_write_order: SubchunkWriteOrder = "morton",
328+
rng: np.random.Generator | None = None,
329329
) -> None:
330330
chunk_shape_parsed = parse_shapelike(chunk_shape)
331331
codecs_parsed = parse_codecs(codecs)
@@ -341,6 +341,7 @@ def __init__(
341341
object.__setattr__(self, "index_codecs", index_codecs_parsed)
342342
object.__setattr__(self, "index_location", index_location_parsed)
343343
object.__setattr__(self, "subchunk_write_order", subchunk_write_order)
344+
object.__setattr__(self, "rng", rng)
344345

345346
# Use instance-local lru_cache to avoid memory leaks
346347

@@ -353,14 +354,15 @@ def __init__(
353354

354355
# todo: typedict return type
355356
def __getstate__(self) -> dict[str, Any]:
356-
return self.to_dict()
357+
return {"rng": self.rng, **self.to_dict()}
357358

358359
def __setstate__(self, state: dict[str, Any]) -> None:
359360
config = state["configuration"]
360361
object.__setattr__(self, "chunk_shape", parse_shapelike(config["chunk_shape"]))
361362
object.__setattr__(self, "codecs", parse_codecs(config["codecs"]))
362363
object.__setattr__(self, "index_codecs", parse_codecs(config["index_codecs"]))
363364
object.__setattr__(self, "index_location", parse_index_location(config["index_location"]))
365+
object.__setattr__(self, "rng", state["rng"])
364366

365367
# Use instance-local lru_cache to avoid memory leaks
366368
# object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
@@ -550,21 +552,12 @@ def _subchunk_order_iter(self, chunks_per_shard: tuple[int, ...]) -> Iterable[tu
550552
subchunk_iter = (c[::-1] for c in np.ndindex(chunks_per_shard[::-1]))
551553
case "unordered":
552554
subchunk_list = list(np.ndindex(chunks_per_shard))
553-
random.shuffle(subchunk_list)
555+
(self.rng if self.rng is not None else np.random.default_rng()).shuffle(
556+
subchunk_list
557+
)
554558
subchunk_iter = iter(subchunk_list)
555559
return subchunk_iter
556560

557-
def _subchunk_order_vectorized(self, chunks_per_shard: tuple[int, ...]) -> npt.NDArray[np.intp]:
558-
match self.subchunk_write_order:
559-
case "morton":
560-
subchunk_order_vectorized = _morton_order(chunks_per_shard)
561-
case _:
562-
subchunk_order_vectorized = np.fromiter(
563-
self._subchunk_order_iter(chunks_per_shard),
564-
dtype=np.dtype((int, len(chunks_per_shard))),
565-
)
566-
return subchunk_order_vectorized
567-
568561
async def _encode_single(
569562
self,
570563
shard_array: NDBuffer,
@@ -623,7 +616,7 @@ async def _encode_partial_single(
623616
)
624617

625618
if self._is_complete_shard_write(indexer, chunks_per_shard):
626-
shard_dict = dict.fromkeys(self._subchunk_order_iter(chunks_per_shard))
619+
shard_dict = dict.fromkeys(np.ndindex(chunks_per_shard))
627620
else:
628621
shard_reader = await self._load_full_shard_maybe(
629622
byte_getter=byte_setter,
@@ -633,7 +626,7 @@ async def _encode_partial_single(
633626
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
634627
# Use vectorized lookup for better performance
635628
shard_dict = shard_reader.to_dict_vectorized(
636-
self._subchunk_order_vectorized(chunks_per_shard)
629+
np.array(list(np.ndindex(chunks_per_shard)))
637630
)
638631

639632
await self.codec_pipeline.write(

tests/test_codecs/test_sharding.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -560,26 +560,10 @@ def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
560560
np.testing.assert_array_equal(c3, s3)
561561

562562

563-
@pytest.mark.parametrize(
564-
"subchunk_write_order",
565-
get_args(SubchunkWriteOrder),
566-
)
567-
async def test_encoded_subchunk_write_order(subchunk_write_order: SubchunkWriteOrder) -> None:
568-
"""Subchunks must be physically laid out in the shard in the order specified by
569-
``subchunk_write_order``. We verify this by decoding the shard index and sorting
570-
the chunk coordinates by their byte offset."""
571-
# Use a non-square chunks_per_shard so all three orderings are distinguishable.
572-
chunks_per_shard = (3, 2)
573-
chunk_shape = (4, 4)
574-
shard_shape = tuple(c * s for c, s in zip(chunks_per_shard, chunk_shape, strict=True))
575-
576-
codec = ShardingCodec(
577-
chunk_shape=chunk_shape,
578-
codecs=[BytesCodec()],
579-
index_codecs=[BytesCodec(), Crc32cCodec()],
580-
index_location=ShardingCodecIndexLocation.end,
581-
subchunk_write_order=subchunk_write_order,
582-
)
563+
async def stored_data_and_get_order(
564+
codec: ShardingCodec, chunks_per_shard: tuple[int, ...]
565+
) -> list[tuple[int, ...]]:
566+
shard_shape = tuple(c * s for c, s in zip(chunks_per_shard, codec.chunk_shape, strict=True))
583567
store = MemoryStore()
584568
arr = zarr.create_array(
585569
StorePath(store),
@@ -609,9 +593,65 @@ async def test_encoded_subchunk_write_order(subchunk_write_order: SubchunkWriteO
609593
)
610594

611595
# The physical write order is recovered by sorting coordinates by start offset.
612-
actual_order = [coord for _, coord in sorted(offset_to_coord.items())]
613-
expected_order = list(codec._subchunk_order_iter(chunks_per_shard))
614-
assert (actual_order == expected_order) == (subchunk_write_order != "unordered")
596+
return [coord for _, coord in sorted(offset_to_coord.items())]
597+
598+
599+
@pytest.mark.parametrize(
600+
"subchunk_write_order",
601+
get_args(SubchunkWriteOrder),
602+
)
603+
async def test_encoded_subchunk_write_order(subchunk_write_order: SubchunkWriteOrder) -> None:
604+
"""Subchunks must be physically laid out in the shard in the order specified by
605+
``subchunk_write_order``. We verify this by decoding the shard index and sorting
606+
the chunk coordinates by their byte offset."""
607+
# Use a non-square chunks_per_shard so all three orderings are distinguishable.
608+
chunks_per_shard = (3, 2)
609+
chunk_shape = (4, 4)
610+
seed = 0
611+
codec = ShardingCodec(
612+
chunk_shape=chunk_shape,
613+
codecs=[BytesCodec()],
614+
index_codecs=[BytesCodec(), Crc32cCodec()],
615+
index_location=ShardingCodecIndexLocation.end,
616+
subchunk_write_order=subchunk_write_order,
617+
rng=np.random.default_rng(seed=seed),
618+
)
619+
620+
actual_order = await stored_data_and_get_order(codec, chunks_per_shard)
621+
if subchunk_write_order != "unordered":
622+
expected_order = list(codec._subchunk_order_iter(chunks_per_shard))
623+
assert actual_order == expected_order
624+
else:
625+
same_order_same_seed = list(
626+
ShardingCodec(
627+
chunk_shape=chunk_shape,
628+
codecs=[BytesCodec()],
629+
index_codecs=[BytesCodec(), Crc32cCodec()],
630+
index_location=ShardingCodecIndexLocation.end,
631+
subchunk_write_order=subchunk_write_order,
632+
rng=np.random.default_rng(seed=seed),
633+
)._subchunk_order_iter(chunks_per_shard)
634+
)
635+
assert actual_order == same_order_same_seed
636+
637+
638+
async def test_unordered_can_be_seeded() -> None:
639+
orders = []
640+
chunks_per_shard = (3, 2)
641+
chunk_shape = (4, 4)
642+
seed = 0
643+
for _ in range(4):
644+
codec = ShardingCodec(
645+
chunk_shape=chunk_shape,
646+
codecs=[BytesCodec()],
647+
index_codecs=[BytesCodec(), Crc32cCodec()],
648+
index_location=ShardingCodecIndexLocation.end,
649+
subchunk_write_order="unordered",
650+
rng=np.random.default_rng(seed=seed),
651+
)
652+
# The physical write order is recovered by sorting coordinates by start offset.
653+
orders.append(await stored_data_and_get_order(codec, chunks_per_shard))
654+
assert all(orders[0] == o for o in orders)
615655

616656

617657
@pytest.mark.parametrize(

0 commit comments

Comments (0)