11from __future__ import annotations
22
33import random
4- from collections .abc import Iterable , Mapping , MutableMapping
4+ from collections .abc import Iterable , Mapping , MutableMapping , Sequence
55from dataclasses import dataclass , replace
66from enum import Enum
77from functools import lru_cache
4646from zarr .core .dtype .npy .int import UInt64
4747from zarr .core .indexing import (
4848 BasicIndexer ,
49+ ChunkProjection ,
4950 SelectorTuple ,
51+ _morton_order ,
5052 _morton_order_keys ,
5153 c_order_iter ,
5254 get_indexer ,
@@ -543,7 +545,7 @@ async def _decode_partial_single(
543545 else :
544546 return out
545547
546- def _subchunk_iter (self , chunks_per_shard : tuple [int , ...]) -> Iterable [tuple [int , ...]]:
548+ def _subchunk_order_iter (self , chunks_per_shard : tuple [int , ...]) -> Iterable [tuple [int , ...]]:
547549 match self .subchunk_write_order :
548550 case SubchunkWriteOrder .morton :
549551 subchunk_iter = morton_order_iter (chunks_per_shard )
@@ -557,6 +559,17 @@ def _subchunk_iter(self, chunks_per_shard: tuple[int, ...]) -> Iterable[tuple[in
557559 subchunk_iter = iter (subchunk_list )
558560 return subchunk_iter
559561
562+ def _subchunk_order_vectorized (self , chunks_per_shard : tuple [int , ...]) -> npt .NDArray [np .intp ]:
563+ match self .subchunk_write_order :
564+ case SubchunkWriteOrder .morton :
565+ subchunk_order_vectorized = _morton_order (chunks_per_shard )
566+ case _:
567+ subchunk_order_vectorized = np .fromiter (
568+ self ._subchunk_order_iter (chunks_per_shard ),
569+ dtype = np .dtype ((int , len (chunks_per_shard ))),
570+ )
571+ return subchunk_order_vectorized
572+
560573 async def _encode_single (
561574 self ,
562575 shard_array : NDBuffer ,
@@ -574,7 +587,7 @@ async def _encode_single(
574587 chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape ),
575588 )
576589 )
577- shard_builder = dict .fromkeys (self ._subchunk_iter (chunks_per_shard ))
590+ shard_builder = dict .fromkeys (self ._subchunk_order_iter (chunks_per_shard ))
578591
579592 await self .codec_pipeline .write (
580593 [
@@ -608,23 +621,26 @@ async def _encode_partial_single(
608621 chunks_per_shard = self ._get_chunks_per_shard (shard_spec )
609622 chunk_spec = self ._get_chunk_spec (shard_spec )
610623
611- shard_reader = await self ._load_full_shard_maybe (
612- byte_getter = byte_setter ,
613- prototype = chunk_spec .prototype ,
614- chunks_per_shard = chunks_per_shard ,
615- )
616- shard_reader = shard_reader or _ShardReader .create_empty (chunks_per_shard )
617- # Use vectorized lookup for better performance
618- shard_dict = shard_reader .to_dict_vectorized (
619- np .asarray (list (self ._subchunk_iter (chunks_per_shard )))
620- )
621-
622624 indexer = list (
623625 get_indexer (
624626 selection , shape = shard_shape , chunk_grid = RegularChunkGrid (chunk_shape = chunk_shape )
625627 )
626628 )
627629
630+ if self ._is_complete_shard_write (indexer , chunks_per_shard ):
631+ shard_dict = dict .fromkeys (self ._subchunk_order_iter (chunks_per_shard ))
632+ else :
633+ shard_reader = await self ._load_full_shard_maybe (
634+ byte_getter = byte_setter ,
635+ prototype = chunk_spec .prototype ,
636+ chunks_per_shard = chunks_per_shard ,
637+ )
638+ shard_reader = shard_reader or _ShardReader .create_empty (chunks_per_shard )
639+ # Use vectorized lookup for better performance
640+ shard_dict = shard_reader .to_dict_vectorized (
641+ self ._subchunk_order_vectorized (chunks_per_shard )
642+ )
643+
628644 await self .codec_pipeline .write (
629645 [
630646 (
@@ -661,7 +677,7 @@ async def _encode_shard_dict(
661677
662678 template = buffer_prototype .buffer .create_zero_length ()
663679 chunk_start = 0
664- for chunk_coords in self ._subchunk_iter (chunks_per_shard ):
680+ for chunk_coords in self ._subchunk_order_iter (chunks_per_shard ):
665681 value = map .get (chunk_coords )
666682 if value is None :
667683 continue
@@ -697,6 +713,16 @@ def _is_total_shard(
697713 chunk_coords in all_chunk_coords for chunk_coords in c_order_iter (chunks_per_shard )
698714 )
699715
def _is_complete_shard_write(
    self,
    indexed_chunks: Sequence[ChunkProjection],
    chunks_per_shard: tuple[int, ...],
) -> bool:
    """Tell whether a write replaces every sub-chunk of the shard in full.

    Parameters
    ----------
    indexed_chunks
        Chunk projections produced by the indexer for this write. The first
        field of each projection is read as the sub-chunk coordinates and the
        last field as its "is complete chunk" flag.
    chunks_per_shard
        Number of sub-chunks along each dimension of the shard.

    Returns
    -------
    bool
        Truthy only when ``_is_total_shard`` accepts the set of touched
        coordinates and every projection's completeness flag is truthy.
    """
    # Coordinates of every sub-chunk the selection touches (first field).
    touched_coords = {projection[0] for projection in indexed_chunks}
    covers_all_subchunks = self._is_total_shard(touched_coords, chunks_per_shard)
    # The flag scan is skipped when coverage already fails (short-circuit).
    return covers_all_subchunks and all(projection[-1] for projection in indexed_chunks)
725+
700726 async def _decode_shard_index (
701727 self , index_bytes : Buffer , chunks_per_shard : tuple [int , ...]
702728 ) -> _ShardIndex :
0 commit comments