@@ -30,7 +30,7 @@ class DiscontiguousArrayError(Exception):
3030 pass
3131
3232
33- class CollapsedDimensionError (Exception ):
33+ class UnsupportedVIndexingError (Exception ):
3434 pass
3535
3636
@@ -160,7 +160,7 @@ def make_chunk_info_for_rust_with_indices(
160160 drop_axes : tuple [int , ...],
161161 shape : tuple [int , ...],
162162) -> RustChunkInfo :
163- shape = shape if shape else ( 1 ,) # constant array
163+ is_constant = shape == ()
164164 chunk_info_with_indices : list [ChunkItem ] = []
165165 write_empty_chunks : bool = True
166166 for (
@@ -182,17 +182,39 @@ def make_chunk_info_for_rust_with_indices(
182182 shape_chunk_selection = get_shape_for_selector (
183183 chunk_selection , chunk_spec .shape , pad = True , drop_axes = drop_axes
184184 )
185- if prod_op (shape_chunk_selection ) != prod_op (shape_chunk_selection_slices ):
186- raise CollapsedDimensionError (
185+ if ( chunk_size := prod_op (shape_chunk_selection ) ) != prod_op (shape_chunk_selection_slices ):
186+ raise UnsupportedVIndexingError (
187187 f"{ shape_chunk_selection } != { shape_chunk_selection_slices } "
188188 )
189+ if not is_constant and chunk_size > prod_op (shape ):
190+ raise IndexError (
191+ f"the size of the chunk subset { chunk_size } and input/output subset { prod_op (shape )} are incompatible"
192+ )
193+ # We need to have io_array_shape and out_selection_expanded with dimensionalities matching that of the underlying array.
194+ # So if we detect that a dimension has been dropped (due to a singleton axis) when converting to slices, we update these two values.
195+ if not is_constant and len (shape_chunk_selection ) != len (shape_chunk_selection_slices ):
196+ shape_ctr = 0
197+ io_array_shape = []
198+ out_selection_expanded = []
199+ for shape_chunk in shape_chunk_selection_slices :
200+ # Append 1/size-1 slice if this dimension has been dropped on the io_array i.e., shape_chunk_selection has been exhausted so there is an extra 1-sized dimension at the end or has a mismatch with the "full" chunk shape `shape_chunk_selection_slices`.
201+ if shape_chunk == 1 and (shape_ctr >= len (shape_chunk_selection ) or shape_chunk != shape_chunk_selection [shape_ctr ]):
202+ io_array_shape += [1 ]
203+ out_selection_expanded += [slice (0 , 1 )]
204+ else :
205+ io_array_shape += [shape [shape_ctr ]]
206+ out_selection_expanded += [out_selection_as_slices [shape_ctr ]]
207+ shape_ctr += 1
208+ else :
209+ io_array_shape = shape
210+ out_selection_expanded = out_selection_as_slices
189211 chunk_info_with_indices .append (
190212 ChunkItem (
191213 key = byte_getter .path ,
192214 chunk_subset = chunk_selection_as_slices ,
193215 chunk_shape = chunk_spec .shape ,
194- subset = out_selection_as_slices ,
195- shape = shape ,
216+ subset = out_selection_expanded ,
217+ shape = io_array_shape ,
196218 )
197219 )
198220 return RustChunkInfo (chunk_info_with_indices , write_empty_chunks )
0 commit comments