@@ -877,6 +877,152 @@ def decode_chunks_from_index(
877877 return out
878878
879879
def _expand_drop_axes(value: NDBuffer, drop_axes: tuple[int, ...], ndim: int) -> NDBuffer:
    """Re-insert singleton axes that the selection dropped from *value*."""
    item = tuple(None if axis in drop_axes else slice(None) for axis in range(ndim))
    return value[item]


def _decode_writable_chunk(
    transform: Any,
    raw: Buffer | None,
    spec: ArraySpec,
    chunk_shape: tuple[int, ...] | None = None,
) -> NDBuffer:
    """Decode *raw* into a writable chunk array, or build one from the fill value.

    Decoded buffers may be read-only views over the stored bytes; copy them so
    the subsequent in-place merge (``chunk[sel] = value``) is safe.
    """
    if raw is None:
        return spec.prototype.nd_buffer.create(
            shape=spec.shape,
            dtype=spec.dtype.to_native_dtype(),
            fill_value=fill_value_or_default(spec),
        )
    if chunk_shape is not None:
        chunk_array = transform.decode_chunk(raw, chunk_shape=chunk_shape)
    else:
        chunk_array = transform.decode_chunk(raw)
    nd = chunk_array.as_ndarray_like()
    if not nd.flags.writeable:  # type: ignore[attr-defined]
        chunk_array = spec.prototype.nd_buffer.from_ndarray_like(nd.copy())
    return chunk_array


def merge_and_encode_from_index(
    existing_raw: dict[tuple[int, ...], Buffer | None],
    index: ShardIndex,
    value: NDBuffer,
    chunk_spec: ArraySpec,
    chunk_selection: SelectorTuple,
    out_selection: SelectorTuple,
    drop_axes: tuple[int, ...],
) -> dict[tuple[int, ...], Buffer | None]:
    """Merge new data into existing chunk(s) and encode, using ``index.leaf_transform``.

    For non-sharded layouts (``index.is_sharded`` is False): decode the single
    existing chunk (or create one from the fill value), merge *value* at the given
    selection, and encode. Returns ``{(0, ...): encoded}``.

    For sharded layouts (``index.is_sharded`` is True): start with existing raw
    chunks, fill missing coords with None, then iterate over affected inner
    chunks using ``get_indexer``. Decode/merge/encode each. Returns the full
    chunk dict for subsequent packing into a shard blob.

    Parameters
    ----------
    existing_raw
        Raw encoded bytes previously read, keyed by chunk coordinate.
    index
        Resolved shard index; ``index.leaf_transform`` must be set.
    value
        The array data being written.
    chunk_spec
        Spec of the (outer) chunk/shard being written.
    chunk_selection, out_selection
        Selection within the chunk that receives data, and the matching
        selection within *value* that supplies it.
    drop_axes
        Axes dropped by the selection, re-inserted before assignment.
    """
    from zarr.core.indexing import get_indexer

    assert index.leaf_transform is not None
    transform = index.leaf_transform
    # Hoisted: the native dtype is used by every is_scalar() check below.
    native_dtype = chunk_spec.dtype.to_native_dtype()

    if not index.is_sharded:
        # --- Simple (non-sharded) path: exactly one chunk per stored key ---
        coord = next(iter(existing_raw)) if existing_raw else (0,) * len(chunk_spec.shape)
        chunk_array = _decode_writable_chunk(
            transform, existing_raw.get(coord), chunk_spec, chunk_shape=chunk_spec.shape
        )

        # Merge value into the chunk.
        if chunk_selection == () or is_scalar(value.as_ndarray_like(), native_dtype):
            chunk_value = value
        else:
            chunk_value = value[out_selection]
            if drop_axes:
                chunk_value = _expand_drop_axes(chunk_value, drop_axes, chunk_spec.ndim)
        chunk_array[chunk_selection] = chunk_value

        # Respect write_empty_chunks: an all-fill chunk is stored as a deletion.
        if not chunk_spec.config.write_empty_chunks and chunk_array.all_equal(
            chunk_spec.fill_value
        ):
            return {coord: None}

        # Only pass chunk_shape when it differs from the transform's own shape
        # (presumably a partial edge chunk — TODO confirm against encode_chunk).
        chunk_shape = chunk_spec.shape if chunk_spec.shape != transform.array_spec.shape else None
        return {coord: transform.encode_chunk(chunk_array, chunk_shape=chunk_shape)}

    # --- Sharded path: many inner chunks packed into one stored blob ---
    from zarr.core.chunk_grids import ChunkGrid as _ChunkGrid

    inner_shape = transform.array_spec.shape
    chunks_per_shard = tuple(
        s // cs for s, cs in zip(chunk_spec.shape, inner_shape, strict=True)
    )

    # Start from what was read; every coordinate must be present for packing.
    chunk_dict: dict[tuple[int, ...], Buffer | None] = dict(existing_raw)
    for coord in np.ndindex(chunks_per_shard):
        chunk_dict.setdefault(coord, None)

    inner_spec = ArraySpec(
        shape=inner_shape,
        dtype=chunk_spec.dtype,
        fill_value=chunk_spec.fill_value,
        config=chunk_spec.config,
        prototype=chunk_spec.prototype,
    )

    # Extract the shard's portion of the write value.
    if is_scalar(value.as_ndarray_like(), native_dtype):
        shard_value = value
    else:
        shard_value = value[out_selection]
        if drop_axes:
            shard_value = _expand_drop_axes(shard_value, drop_axes, len(chunk_spec.shape))

    # Determine which inner chunks are affected by the selection.
    indexer = get_indexer(
        chunk_selection,
        shape=chunk_spec.shape,
        chunk_grid=_ChunkGrid.from_sizes(chunk_spec.shape, inner_shape),
    )

    for inner_coords, inner_sel, value_sel, _ in indexer:
        # Decode just this inner chunk (no chunk_shape kwarg: inner chunks use
        # the transform's default shape).
        inner_array = _decode_writable_chunk(
            transform, chunk_dict.get(inner_coords), inner_spec
        )

        # Merge new data into this inner chunk.
        if inner_sel == () or is_scalar(shard_value.as_ndarray_like(), native_dtype):
            inner_value = shard_value
        else:
            inner_value = shard_value[value_sel]
        inner_array[inner_sel] = inner_value

        # Re-encode, or record a deletion when the chunk became all-fill.
        if not chunk_spec.config.write_empty_chunks and inner_array.all_equal(
            chunk_spec.fill_value
        ):
            chunk_dict[inner_coords] = None
        else:
            chunk_dict[inner_coords] = transform.encode_chunk(inner_array)

    return chunk_dict
1024+
1025+
8801026class ChunkLayout :
8811027 """Describes how a stored blob maps to one or more inner chunks.
8821028
@@ -934,6 +1080,12 @@ def store_chunks_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ..
9341080 async def store_chunks_async (self , byte_setter : Any , encoded_chunks : dict [tuple [int , ...], Buffer | None ], chunk_spec : ArraySpec ) -> None :
9351081 raise NotImplementedError
9361082
    def pack_and_store_sync(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
        """Pack *encoded_chunks* into a stored blob and write (or delete) it synchronously.

        Abstract hook on the base layout; concrete layouts override it.
        """
        raise NotImplementedError
1085+
    async def pack_and_store_async(self, byte_setter: Any, encoded_chunks: dict[tuple[int, ...], Buffer | None]) -> None:
        """Async counterpart of ``pack_and_store_sync``.

        Abstract hook on the base layout; concrete layouts override it.
        """
        raise NotImplementedError
1088+
9371089 # -- Low-level helpers --
9381090
9391091 def unpack_blob (self , blob : Buffer ) -> dict [tuple [int , ...], Buffer | None ]:
@@ -1054,6 +1206,22 @@ async def store_chunks_async(self, byte_setter: Any, encoded_chunks: dict[tuple[
10541206 else :
10551207 await byte_setter .set (blob )
10561208
1209+ def pack_and_store_sync (self , byte_setter : Any , encoded_chunks : dict [tuple [int , ...], Buffer | None ]) -> None :
1210+ coord = (0 ,) * len (self .chunks_per_shard )
1211+ blob = encoded_chunks .get (coord )
1212+ if blob is None :
1213+ byte_setter .delete_sync () # type: ignore[attr-defined]
1214+ else :
1215+ byte_setter .set_sync (blob ) # type: ignore[attr-defined]
1216+
1217+ async def pack_and_store_async (self , byte_setter : Any , encoded_chunks : dict [tuple [int , ...], Buffer | None ]) -> None :
1218+ coord = (0 ,) * len (self .chunks_per_shard )
1219+ blob = encoded_chunks .get (coord )
1220+ if blob is None :
1221+ await byte_setter .delete ()
1222+ else :
1223+ await byte_setter .set (blob )
1224+
10571225 # -- Low-level --
10581226
10591227 def unpack_blob (self , blob : Buffer ) -> dict [tuple [int , ...], Buffer | None ]:
@@ -1347,6 +1515,30 @@ async def store_chunks_async(self, byte_setter: Any, encoded_chunks: dict[tuple[
13471515 else :
13481516 await byte_setter .set (blob )
13491517
1518+ def pack_and_store_sync (self , byte_setter : Any , encoded_chunks : dict [tuple [int , ...], Buffer | None ]) -> None :
1519+ from zarr .core .buffer import default_buffer_prototype
1520+
1521+ if all (v is None for v in encoded_chunks .values ()):
1522+ byte_setter .delete_sync () # type: ignore[attr-defined]
1523+ return
1524+ blob = self .pack_blob (encoded_chunks , default_buffer_prototype ())
1525+ if blob is None :
1526+ byte_setter .delete_sync () # type: ignore[attr-defined]
1527+ else :
1528+ byte_setter .set_sync (blob ) # type: ignore[attr-defined]
1529+
1530+ async def pack_and_store_async (self , byte_setter : Any , encoded_chunks : dict [tuple [int , ...], Buffer | None ]) -> None :
1531+ from zarr .core .buffer import default_buffer_prototype
1532+
1533+ if all (v is None for v in encoded_chunks .values ()):
1534+ await byte_setter .delete ()
1535+ return
1536+ blob = self .pack_blob (encoded_chunks , default_buffer_prototype ())
1537+ if blob is None :
1538+ await byte_setter .delete ()
1539+ else :
1540+ await byte_setter .set (blob )
1541+
13501542 def _decode_per_chunk (
13511543 self ,
13521544 chunk_dict : dict [tuple [int , ...], Buffer | None ],
@@ -1760,7 +1952,7 @@ async def _process_chunk(
17601952
17611953 # Phase 1: resolve index (IO)
17621954 if is_complete :
1763- index = ShardIndex (key = key )
1955+ index = ShardIndex (key = key , leaf_transform = layout . inner_transform , is_sharded = layout . is_sharded )
17641956 elif layout .is_sharded :
17651957 async with sem :
17661958 index = await layout .resolve_index_async (byte_setter , key , chunk_selection = None ) # ALL coords
@@ -1771,25 +1963,26 @@ async def _process_chunk(
17711963 # Phase 2: fetch existing chunks (IO)
17721964 if index .chunks :
17731965 async with sem :
1774- existing_chunks = await layout . fetch_chunks_async (byte_setter , index , prototype = chunk_spec .prototype )
1966+ existing = await fetch_chunks_async (byte_setter , index , prototype = chunk_spec .prototype )
17751967 else :
1776- existing_chunks = {}
1968+ existing = {}
17771969
17781970 # Phase 3: merge and encode (compute)
1779- encoded_chunks = await loop .run_in_executor (
1971+ encoded = await loop .run_in_executor (
17801972 pool ,
1781- layout .merge_and_encode ,
1782- existing_chunks ,
1973+ merge_and_encode_from_index ,
1974+ existing ,
1975+ index ,
17831976 value ,
17841977 chunk_spec ,
17851978 chunk_selection ,
17861979 out_selection ,
17871980 drop_axes ,
17881981 )
17891982
1790- # Phase 4: store (IO)
1983+ # Phase 4: pack + store (IO)
17911984 async with sem :
1792- await layout .store_chunks_async (byte_setter , encoded_chunks , chunk_spec )
1985+ await layout .pack_and_store_async (byte_setter , encoded )
17931986
17941987 await asyncio .gather (
17951988 * [
@@ -1868,24 +2061,33 @@ def write_sync(
18682061 if not batch :
18692062 return
18702063
2064+ assert self .layout is not None
2065+ default_layout = self .layout
2066+
18712067 for bs , chunk_spec , chunk_selection , out_selection , is_complete in batch :
18722068 layout = (
1873- self . layout
1874- if self . layout is not None and chunk_spec .shape == self . layout .chunk_shape
2069+ default_layout
2070+ if chunk_spec .shape == default_layout .chunk_shape
18752071 else self ._get_layout (chunk_spec )
18762072 )
18772073 key = bs .path if hasattr (bs , "path" ) else ""
18782074
2075+ # Phase 1: resolve index
18792076 if is_complete :
1880- index = ShardIndex (key = key )
2077+ index = ShardIndex (key = key , leaf_transform = layout . inner_transform , is_sharded = layout . is_sharded )
18812078 elif layout .is_sharded :
18822079 index = layout .resolve_index (bs , key , chunk_selection = None ) # ALL coords
18832080 else :
18842081 index = layout .resolve_index (bs , key , chunk_selection = chunk_selection )
18852082
1886- existing_chunks = layout .fetch_chunks (bs , index , prototype = chunk_spec .prototype ) if index .chunks else {}
1887- encoded_chunks = layout .merge_and_encode (existing_chunks , value , chunk_spec , chunk_selection , out_selection , drop_axes )
1888- layout .store_chunks_sync (bs , encoded_chunks , chunk_spec )
2083+ # Phase 2: fetch existing
2084+ existing = fetch_chunks_sync (bs , index , prototype = chunk_spec .prototype ) if index .chunks else {}
2085+
2086+ # Phase 3: merge + encode (compute)
2087+ encoded = merge_and_encode_from_index (existing , index , value , chunk_spec , chunk_selection , out_selection , drop_axes )
2088+
2089+ # Phase 4: pack + store
2090+ layout .pack_and_store_sync (bs , encoded )
18892091
18902092
18912093register_pipeline (PhasedCodecPipeline )
0 commit comments