@@ -234,7 +234,6 @@ def create_empty(cls, chunks_per_shard: tuple[int, ...]) -> _ShardIndex:
234234class _ShardReader (ShardMapping ):
235235 buf : Buffer
236236 index : _ShardIndex
237- order : SubchunkWriteOrder
238237
239238 @classmethod
240239 async def from_bytes (
@@ -540,8 +539,10 @@ async def _decode_partial_single(
540539 else :
541540 return out
542541
543- def _subchunk_order_iter (self , chunks_per_shard : tuple [int , ...]) -> Iterable [tuple [int , ...]]:
544- match self .subchunk_write_order :
542+ def _subchunk_order_iter (
543+ self , chunks_per_shard : tuple [int , ...], subchunk_write_order : SubchunkWriteOrder
544+ ) -> Iterable [tuple [int , ...]]:
545+ match subchunk_write_order :
545546 case "morton" :
546547 subchunk_iter = morton_order_iter (chunks_per_shard )
547548 case "lexicographic" :
@@ -574,10 +575,6 @@ async def _encode_single(
574575 )
575576 )
576577 shard_builder = dict .fromkeys (np .array (list (np .ndindex (chunks_per_shard ))))
577- assert (
578- shard_builder .keys ()
579- == dict .fromkeys (self ._subchunk_order_iter (chunks_per_shard )).keys ()
580- )
581578
582579 await self .codec_pipeline .write (
583580 [
@@ -618,7 +615,7 @@ async def _encode_partial_single(
618615 )
619616
620617 if self ._is_complete_shard_write (indexer , chunks_per_shard ):
621- shard_dict = dict .fromkeys (np . ndindex (chunks_per_shard ))
618+ shard_dict = dict .fromkeys (self . _subchunk_order_iter (chunks_per_shard , "lexicographic" ))
622619 else :
623620 shard_reader = await self ._load_full_shard_maybe (
624621 byte_getter = byte_setter ,
@@ -628,7 +625,7 @@ async def _encode_partial_single(
628625 shard_reader = shard_reader or _ShardReader .create_empty (chunks_per_shard )
629626 # Use vectorized lookup for better performance
630627 shard_dict = shard_reader .to_dict_vectorized (
631- np .array (list (np . ndindex (chunks_per_shard )))
628+ np .array (list (self . _subchunk_order_iter (chunks_per_shard , "lexicographic" )))
632629 )
633630
634631 await self .codec_pipeline .write (
@@ -667,7 +664,7 @@ async def _encode_shard_dict(
667664
668665 template = buffer_prototype .buffer .create_zero_length ()
669666 chunk_start = 0
670- for chunk_coords in self ._subchunk_order_iter (chunks_per_shard ):
667+ for chunk_coords in self ._subchunk_order_iter (chunks_per_shard , self . subchunk_write_order ):
671668 value = map .get (chunk_coords )
672669 if value is None :
673670 continue
0 commit comments