11from __future__ import annotations
22
3+ import random
34from collections .abc import Iterable , Mapping , MutableMapping
45from dataclasses import dataclass , replace
56from enum import Enum
4647from zarr .core .indexing import (
4748 BasicIndexer ,
4849 SelectorTuple ,
49- _morton_order ,
5050 _morton_order_keys ,
5151 c_order_iter ,
5252 get_indexer ,
@@ -77,10 +77,27 @@ class ShardingCodecIndexLocation(Enum):
7777 end = "end"
7878
7979
class SubchunkWriteOrder(Enum):
    """
    Enum for the order in which the subchunks of a shard are traversed when
    the shard is written.

    `ShardingCodec` accepts either a member of this enum or a string, which it
    passes through `parse_subchunk_write_order`.

    unordered is implemented via `random.shuffle` over the lexicographic order,
    so the resulting order differs between calls unless the global `random`
    state is seeded.
    """

    # Coordinates produced by `morton_order_iter`; the default of `ShardingCodec`.
    morton = "morton"
    # Random permutation of the lexicographic coordinates.
    unordered = "unordered"
    # Coordinates produced by `np.ndindex` over the subchunk grid.
    lexicographic = "lexicographic"
    # `np.ndindex` over the reversed grid shape, with each coordinate reversed back.
    colexicographic = "colexicographic"
91+
92+
def parse_index_location(data: object) -> ShardingCodecIndexLocation:
    """Coerce `data` to a `ShardingCodecIndexLocation` by delegating to `parse_enum`."""
    index_location = parse_enum(data, ShardingCodecIndexLocation)
    return index_location
8295
8396
def parse_subchunk_write_order(data: object) -> SubchunkWriteOrder:
    """Coerce `data` to a `SubchunkWriteOrder` by delegating to `parse_enum`."""
    write_order = parse_enum(data, SubchunkWriteOrder)
    return write_order
99+
100+
84101@dataclass (frozen = True )
85102class _ShardingByteGetter (ByteGetter ):
86103 shard_dict : ShardMapping
@@ -305,6 +322,7 @@ class ShardingCodec(
305322 codecs : tuple [Codec , ...]
306323 index_codecs : tuple [Codec , ...]
307324 index_location : ShardingCodecIndexLocation = ShardingCodecIndexLocation .end
325+ subchunk_write_order : SubchunkWriteOrder = SubchunkWriteOrder .morton
308326
309327 def __init__ (
310328 self ,
@@ -313,16 +331,19 @@ def __init__(
313331 codecs : Iterable [Codec | dict [str , JSON ]] = (BytesCodec (),),
314332 index_codecs : Iterable [Codec | dict [str , JSON ]] = (BytesCodec (), Crc32cCodec ()),
315333 index_location : ShardingCodecIndexLocation | str = ShardingCodecIndexLocation .end ,
334+ subchunk_write_order : SubchunkWriteOrder | str = SubchunkWriteOrder .morton ,
316335 ) -> None :
317336 chunk_shape_parsed = parse_shapelike (chunk_shape )
318337 codecs_parsed = parse_codecs (codecs )
319338 index_codecs_parsed = parse_codecs (index_codecs )
320339 index_location_parsed = parse_index_location (index_location )
340+ subchunk_write_order_parsed = parse_subchunk_write_order (subchunk_write_order )
321341
322342 object .__setattr__ (self , "chunk_shape" , chunk_shape_parsed )
323343 object .__setattr__ (self , "codecs" , codecs_parsed )
324344 object .__setattr__ (self , "index_codecs" , index_codecs_parsed )
325345 object .__setattr__ (self , "index_location" , index_location_parsed )
346+ object .__setattr__ (self , "subchunk_write_order" , subchunk_write_order_parsed )
326347
327348 # Use instance-local lru_cache to avoid memory leaks
328349
@@ -522,6 +543,20 @@ async def _decode_partial_single(
522543 else :
523544 return out
524545
546+ def _subchunk_iter (self , chunks_per_shard : tuple [int , ...]) -> Iterable [tuple [int , ...]]:
547+ match self .subchunk_write_order :
548+ case SubchunkWriteOrder .morton :
549+ subchunk_iter = morton_order_iter (chunks_per_shard )
550+ case SubchunkWriteOrder .lexicographic :
551+ subchunk_iter = np .ndindex (chunks_per_shard )
552+ case SubchunkWriteOrder .colexicographic :
553+ subchunk_iter = (c [::- 1 ] for c in np .ndindex (chunks_per_shard [::- 1 ]))
554+ case SubchunkWriteOrder .unordered :
555+ subchunk_list = list (np .ndindex (chunks_per_shard ))
556+ random .shuffle (subchunk_list )
557+ subchunk_iter = iter (subchunk_list )
558+ return subchunk_iter
559+
525560 async def _encode_single (
526561 self ,
527562 shard_array : NDBuffer ,
@@ -539,8 +574,7 @@ async def _encode_single(
539574 chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape ),
540575 )
541576 )
542-
543- shard_builder = dict .fromkeys (morton_order_iter (chunks_per_shard ))
577+ shard_builder = dict .fromkeys (self ._subchunk_iter (chunks_per_shard ))
544578
545579 await self .codec_pipeline .write (
546580 [
@@ -581,7 +615,9 @@ async def _encode_partial_single(
581615 )
582616 shard_reader = shard_reader or _ShardReader .create_empty (chunks_per_shard )
583617 # Use vectorized lookup for better performance
584- shard_dict = shard_reader .to_dict_vectorized (np .asarray (_morton_order (chunks_per_shard )))
618+ shard_dict = shard_reader .to_dict_vectorized (
619+ np .asarray (list (self ._subchunk_iter (chunks_per_shard )))
620+ )
585621
586622 indexer = list (
587623 get_indexer (
@@ -625,7 +661,7 @@ async def _encode_shard_dict(
625661
626662 template = buffer_prototype .buffer .create_zero_length ()
627663 chunk_start = 0
628- for chunk_coords in morton_order_iter (chunks_per_shard ):
664+ for chunk_coords in self . _subchunk_iter (chunks_per_shard ):
629665 value = map .get (chunk_coords )
630666 if value is None :
631667 continue
0 commit comments