Skip to content

Commit b0c622d

Browse files
committed
refactor: no enums
1 parent 11b94c0 commit b0c622d

2 files changed

Lines changed: 25 additions & 32 deletions

File tree

src/zarr/codecs/sharding.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from enum import Enum
77
from functools import lru_cache
88
from operator import itemgetter
9-
from typing import TYPE_CHECKING, Any, NamedTuple, cast
9+
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, cast
1010

1111
import numpy as np
1212
import numpy.typing as npt
@@ -60,7 +60,7 @@
6060

6161
if TYPE_CHECKING:
6262
from collections.abc import Iterator
63-
from typing import Self
63+
from typing import Final, Self
6464

6565
from zarr.core.common import JSON
6666
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
@@ -79,27 +79,19 @@ class ShardingCodecIndexLocation(Enum):
7979
end = "end"
8080

8181

82-
class SubchunkWriteOrder(Enum):
83-
"""
84-
Enum for the order of the chunks within a shard.
85-
86-
unordered is implemented via `random.shuffle` over the lexicographic order.
87-
"""
88-
89-
morton = "morton"
90-
unordered = "unordered"
91-
lexicographic = "lexicographic"
92-
colexicographic = "colexicographic"
82+
SubchunkWriteOrder = Literal["morton", "unordered", "lexicographic", "colexicographic"]
83+
SUBCHUNK_WRITE_ORDER: Final[tuple[str, str, str, str]] = (
84+
"morton",
85+
"unordered",
86+
"lexicographic",
87+
"colexicographic",
88+
)
9389

9490

9591
def parse_index_location(data: object) -> ShardingCodecIndexLocation:
9692
return parse_enum(data, ShardingCodecIndexLocation)
9793

9894

99-
def parse_subchunk_write_order(data: object) -> SubchunkWriteOrder:
100-
return parse_enum(data, SubchunkWriteOrder)
101-
102-
10395
@dataclass(frozen=True)
10496
class _ShardingByteGetter(ByteGetter):
10597
shard_dict: ShardMapping
@@ -324,7 +316,7 @@ class ShardingCodec(
324316
codecs: tuple[Codec, ...]
325317
index_codecs: tuple[Codec, ...]
326318
index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end
327-
subchunk_write_order: SubchunkWriteOrder = SubchunkWriteOrder.morton
319+
subchunk_write_order: SubchunkWriteOrder = "morton"
328320

329321
def __init__(
330322
self,
@@ -333,19 +325,22 @@ def __init__(
333325
codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(),),
334326
index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()),
335327
index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end,
336-
subchunk_write_order: SubchunkWriteOrder | str = SubchunkWriteOrder.morton,
328+
subchunk_write_order: SubchunkWriteOrder = "morton",
337329
) -> None:
338330
chunk_shape_parsed = parse_shapelike(chunk_shape)
339331
codecs_parsed = parse_codecs(codecs)
340332
index_codecs_parsed = parse_codecs(index_codecs)
341333
index_location_parsed = parse_index_location(index_location)
342-
subchunk_write_order_parsed = parse_subchunk_write_order(subchunk_write_order)
334+
if subchunk_write_order not in SUBCHUNK_WRITE_ORDER:
335+
raise ValueError(
336+
f"Unrecognized subchunk write order: {subchunk_write_order}. Only {SUBCHUNK_WRITE_ORDER} are allowed."
337+
)
343338

344339
object.__setattr__(self, "chunk_shape", chunk_shape_parsed)
345340
object.__setattr__(self, "codecs", codecs_parsed)
346341
object.__setattr__(self, "index_codecs", index_codecs_parsed)
347342
object.__setattr__(self, "index_location", index_location_parsed)
348-
object.__setattr__(self, "subchunk_write_order", subchunk_write_order_parsed)
343+
object.__setattr__(self, "subchunk_write_order", subchunk_write_order)
349344

350345
# Use instance-local lru_cache to avoid memory leaks
351346

@@ -547,21 +542,21 @@ async def _decode_partial_single(
547542

548543
def _subchunk_order_iter(self, chunks_per_shard: tuple[int, ...]) -> Iterable[tuple[int, ...]]:
549544
match self.subchunk_write_order:
550-
case SubchunkWriteOrder.morton:
545+
case "morton":
551546
subchunk_iter = morton_order_iter(chunks_per_shard)
552-
case SubchunkWriteOrder.lexicographic:
547+
case "lexicographic":
553548
subchunk_iter = np.ndindex(chunks_per_shard)
554-
case SubchunkWriteOrder.colexicographic:
549+
case "colexicographic":
555550
subchunk_iter = (c[::-1] for c in np.ndindex(chunks_per_shard[::-1]))
556-
case SubchunkWriteOrder.unordered:
551+
case "unordered":
557552
subchunk_list = list(np.ndindex(chunks_per_shard))
558553
random.shuffle(subchunk_list)
559554
subchunk_iter = iter(subchunk_list)
560555
return subchunk_iter
561556

562557
def _subchunk_order_vectorized(self, chunks_per_shard: tuple[int, ...]) -> npt.NDArray[np.intp]:
563558
match self.subchunk_write_order:
564-
case SubchunkWriteOrder.morton:
559+
case "morton":
565560
subchunk_order_vectorized = _morton_order(chunks_per_shard)
566561
case _:
567562
subchunk_order_vectorized = np.fromiter(

tests/test_codecs/test_sharding.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pickle
22
import re
3-
from typing import Any
3+
from typing import Any, get_args
44

55
import numpy as np
66
import numpy.typing as npt
@@ -562,7 +562,7 @@ def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
562562

563563
@pytest.mark.parametrize(
564564
"subchunk_write_order",
565-
list(SubchunkWriteOrder),
565+
get_args(SubchunkWriteOrder),
566566
)
567567
async def test_encoded_subchunk_write_order(subchunk_write_order: SubchunkWriteOrder) -> None:
568568
"""Subchunks must be physically laid out in the shard in the order specified by
@@ -611,14 +611,12 @@ async def test_encoded_subchunk_write_order(subchunk_write_order: SubchunkWriteO
611611
# The physical write order is recovered by sorting coordinates by start offset.
612612
actual_order = [coord for _, coord in sorted(offset_to_coord.items())]
613613
expected_order = list(codec._subchunk_order_iter(chunks_per_shard))
614-
assert (actual_order == expected_order) == (
615-
subchunk_write_order != SubchunkWriteOrder.unordered
616-
)
614+
assert (actual_order == expected_order) == (subchunk_write_order != "unordered")
617615

618616

619617
@pytest.mark.parametrize(
620618
"subchunk_write_order",
621-
list(SubchunkWriteOrder),
619+
get_args(SubchunkWriteOrder),
622620
)
623621
@pytest.mark.parametrize("do_partial", [True, False], ids=["partial", "complete"])
624622
def test_subchunk_write_order_roundtrip(

0 commit comments

Comments
 (0)