11from __future__ import annotations
22
3- import random
43from collections .abc import Iterable , Mapping , MutableMapping , Sequence
54from dataclasses import dataclass , replace
65from enum import Enum
4847 BasicIndexer ,
4948 ChunkProjection ,
5049 SelectorTuple ,
51- _morton_order ,
5250 _morton_order_keys ,
5351 c_order_iter ,
5452 get_indexer ,
@@ -315,6 +313,7 @@ class ShardingCodec(
315313 chunk_shape : tuple [int , ...]
316314 codecs : tuple [Codec , ...]
317315 index_codecs : tuple [Codec , ...]
316+ rng : np .random .Generator | None
318317 index_location : ShardingCodecIndexLocation = ShardingCodecIndexLocation .end
319318 subchunk_write_order : SubchunkWriteOrder = "morton"
320319
@@ -326,6 +325,7 @@ def __init__(
326325 index_codecs : Iterable [Codec | dict [str , JSON ]] = (BytesCodec (), Crc32cCodec ()),
327326 index_location : ShardingCodecIndexLocation | str = ShardingCodecIndexLocation .end ,
328327 subchunk_write_order : SubchunkWriteOrder = "morton" ,
328+ rng : np .random .Generator | None = None ,
329329 ) -> None :
330330 chunk_shape_parsed = parse_shapelike (chunk_shape )
331331 codecs_parsed = parse_codecs (codecs )
@@ -341,6 +341,7 @@ def __init__(
341341 object .__setattr__ (self , "index_codecs" , index_codecs_parsed )
342342 object .__setattr__ (self , "index_location" , index_location_parsed )
343343 object .__setattr__ (self , "subchunk_write_order" , subchunk_write_order )
344+ object .__setattr__ (self , "rng" , rng )
344345
345346 # Use instance-local lru_cache to avoid memory leaks
346347
@@ -353,14 +354,15 @@ def __init__(
353354
354355 # todo: typedict return type
355356 def __getstate__ (self ) -> dict [str , Any ]:
356- return self .to_dict ()
357+ return { "rng" : self .rng , ** self . to_dict ()}
357358
358359 def __setstate__ (self , state : dict [str , Any ]) -> None :
359360 config = state ["configuration" ]
360361 object .__setattr__ (self , "chunk_shape" , parse_shapelike (config ["chunk_shape" ]))
361362 object .__setattr__ (self , "codecs" , parse_codecs (config ["codecs" ]))
362363 object .__setattr__ (self , "index_codecs" , parse_codecs (config ["index_codecs" ]))
363364 object .__setattr__ (self , "index_location" , parse_index_location (config ["index_location" ]))
365+ object .__setattr__ (self , "rng" , state ["rng" ])
364366
365367 # Use instance-local lru_cache to avoid memory leaks
366368 # object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
@@ -550,21 +552,12 @@ def _subchunk_order_iter(self, chunks_per_shard: tuple[int, ...]) -> Iterable[tu
550552 subchunk_iter = (c [::- 1 ] for c in np .ndindex (chunks_per_shard [::- 1 ]))
551553 case "unordered" :
552554 subchunk_list = list (np .ndindex (chunks_per_shard ))
553- random .shuffle (subchunk_list )
555+ (self .rng if self .rng is not None else np .random .default_rng ()).shuffle (
556+ subchunk_list
557+ )
554558 subchunk_iter = iter (subchunk_list )
555559 return subchunk_iter
556560
557- def _subchunk_order_vectorized (self , chunks_per_shard : tuple [int , ...]) -> npt .NDArray [np .intp ]:
558- match self .subchunk_write_order :
559- case "morton" :
560- subchunk_order_vectorized = _morton_order (chunks_per_shard )
561- case _:
562- subchunk_order_vectorized = np .fromiter (
563- self ._subchunk_order_iter (chunks_per_shard ),
564- dtype = np .dtype ((int , len (chunks_per_shard ))),
565- )
566- return subchunk_order_vectorized
567-
568561 async def _encode_single (
569562 self ,
570563 shard_array : NDBuffer ,
@@ -623,7 +616,7 @@ async def _encode_partial_single(
623616 )
624617
625618 if self ._is_complete_shard_write (indexer , chunks_per_shard ):
626- shard_dict = dict .fromkeys (self . _subchunk_order_iter (chunks_per_shard ))
619+ shard_dict = dict .fromkeys (np . ndindex (chunks_per_shard ))
627620 else :
628621 shard_reader = await self ._load_full_shard_maybe (
629622 byte_getter = byte_setter ,
@@ -633,7 +626,7 @@ async def _encode_partial_single(
633626 shard_reader = shard_reader or _ShardReader .create_empty (chunks_per_shard )
634627 # Use vectorized lookup for better performance
635628 shard_dict = shard_reader .to_dict_vectorized (
636- self . _subchunk_order_vectorized ( chunks_per_shard )
629+ np . array ( list ( np . ndindex ( chunks_per_shard )) )
637630 )
638631
639632 await self .codec_pipeline .write (
0 commit comments