@@ -2288,7 +2288,7 @@ def slice(self, key: int | slice | Sequence[slice], **kwargs: Any) -> NDArray:
22882288 for order , nchunk in enumerate (aligned_chunks ):
22892289 chunk = self .schunk .get_chunk (nchunk )
22902290 newarr .schunk .update_chunk (order , chunk )
2291- newarr .squeeze (mask = mask ) # remove any dummy dims introduced
2291+ newarr .squeeze (axis = np . where ( mask )[ 0 ] ) # remove any dummy dims introduced
22922292 return newarr
22932293
22942294 key = (start , stop )
@@ -2307,11 +2307,11 @@ def slice(self, key: int | slice | Sequence[slice], **kwargs: Any) -> NDArray:
23072307
23082308 return ndslice
23092309
2310- def squeeze (self , mask = None ) -> NDArray :
2310+ def squeeze (self , axis = None ) -> NDArray :
23112311 """Remove single-dimensional entries from the shape of the array.
23122312
23132313 This method modifies the array in-place. If mask is None removes any dimensions with size 1.
2314- If mask is provided, it should be a boolean array of the same shape as the array, and the corresponding
2314+ If axis is provided, it should be an int or tuple of ints and the corresponding
23152315 dimensions (of size 1) will be removed.
23162316
23172317 Returns
@@ -2331,7 +2331,18 @@ def squeeze(self, mask=None) -> NDArray:
23312331 >>> a.shape
23322332 (23, 11)
23332333 """
2334- super ().squeeze (mask = mask )
2334+ if axis is None :
2335+ super ().squeeze ()
2336+ else :
2337+ axis = [axis ] if isinstance (axis , int ) else axis
2338+ mask = [False for i in range (self .ndim )]
2339+ for a in axis :
2340+ if a < 0 :
2341+ a += self .ndim # Adjust axis to be within the array's dimensions
2342+ if mask [a ]:
2343+ raise ValueError ("Axis values must be unique." )
2344+ mask [a ] = True
2345+ super ().squeeze (mask = mask )
23352346 return self
23362347
23372348 def indices (self , order : str | list [str ] | None = None , ** kwargs : Any ) -> NDArray :
@@ -4312,9 +4323,8 @@ def asarray(
43124323 else :
43134324 if not isinstance (array , NDArray ):
43144325 raise ValueError ("Must always do a copy for asarray unless NDArray provided." )
4315- mask = [True ] + [False for i in range (array .ndim )]
43164326 # TODO: make a direct view possible
4317- return blosc2 .expand_dims (array , axis = 0 ).squeeze (mask ) # way to get a view
4327+ return blosc2 .expand_dims (array , axis = 0 ).squeeze (axis = 0 ) # way to get a view
43184328
43194329 return ndarr
43204330
@@ -4515,15 +4525,15 @@ def sort(array: NDArray, order: str | list[str] | None = None, **kwargs: Any) ->
45154525 return larr .sort (order ).compute (** kwargs )
45164526
45174527
4518- def matmul (x1 : NDArray , x2 : NDArray , ** kwargs : Any ) -> NDArray :
4528+ def matmul (x1 : NDArray | np . ndarray , x2 : NDArray , ** kwargs : Any ) -> NDArray | np . ndarray : # noqa : C901
45194529 """
45204530 Computes the matrix product between two Blosc2 NDArrays.
45214531
45224532 Parameters
45234533 ----------
4524- x1: :ref:`NDArray`
4534+ x1: :ref:`NDArray` | np.ndarray
45254535 The first input array.
4526- x2: :ref:`NDArray`
4536+ x2: :ref:`NDArray` | np.ndarray
45274537 The second input array.
45284538 kwargs: Any, optional
45294539 Keyword arguments that are supported by the :func:`empty` constructor.
@@ -4575,51 +4585,70 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
45754585 array([1, 5])
45764586
45774587 """
4588+ # Added this to pass array-api tests (which use internal getitem to check results)
4589+ if isinstance (x1 , np .ndarray ) and isinstance (x2 , np .ndarray ):
4590+ return np .matmul (x1 , x2 )
45784591
45794592 # Validate arguments are not scalars
45804593 if np .isscalar (x1 ) or np .isscalar (x2 ):
45814594 raise ValueError ("Arguments can't be scalars." )
45824595
4583- # Validate arguments are dimension 1 or 2
4584- if x1 .ndim > 2 or x2 .ndim > 2 :
4585- raise ValueError ("Multiplication of arrays with dimension greater than 2 is not supported yet ." )
4596+ # Validate matrix multiplication compatibility
4597+ if x1 .shape [ - 1 ] != x2 . shape [ builtins . max ( - 2 , - len ( x2 .shape ))] :
4598+ raise ValueError ("Shapes are not aligned for matrix multiplication ." )
45864599
45874600 # Promote 1D arrays to 2D if necessary
45884601 x1_is_vector = False
45894602 x2_is_vector = False
45904603 if x1 .ndim == 1 :
4591- x1 = x1 . reshape (( 1 , x1 . shape [ 0 ]) ) # (N,) -> (1, N)
4604+ x1 = blosc2 . expand_dims ( x1 , axis = 0 ) # (N,) -> (1, N)
45924605 x1_is_vector = True
45934606 if x2 .ndim == 1 :
4594- x2 = x2 . reshape (( x2 . shape [ 0 ], 1 ) ) # (M,) -> (M, 1)
4607+ x2 = blosc2 . expand_dims ( x2 , axis = 1 ) # (M,) -> (M, 1)
45954608 x2_is_vector = True
45964609
4597- # Validate matrix multiplication compatibility
4598- if x1 .shape [- 1 ] != x2 .shape [- 2 ]:
4599- raise ValueError ("Shapes are not aligned for matrix multiplication." )
4600-
46014610 n , k = x1 .shape [- 2 :]
46024611 m = x2 .shape [- 1 ]
4612+ result_shape = np .broadcast_shapes (x1 .shape [:- 2 ], x2 .shape [:- 2 ]) + (n , m )
4613+ result = blosc2 .zeros (result_shape , dtype = np .result_type (x1 , x2 ), ** kwargs )
46034614
4604- result = blosc2 .zeros ((n , m ), dtype = np .result_type (x1 , x2 ), ** kwargs )
4615+ if 0 in result .shape + x1 .shape + x2 .shape : # if any array is empty, return array of 0s
4616+ if x1_is_vector :
4617+ result .squeeze (axis = - 2 )
4618+ if x2_is_vector :
4619+ result .squeeze (axis = - 1 )
4620+ return result
46054621
46064622 p , q = result .chunks [- 2 :]
46074623 r = x2 .chunks [- 1 ]
46084624
4609- for row in range (0 , n , p ):
4610- row_end = builtins .min (row + p , n )
4611- for col in range (0 , m , q ):
4612- col_end = builtins .min (col + q , m )
4613- for aux in range (0 , k , r ):
4614- aux_end = builtins .min (aux + r , k )
4615- bx1 = x1 [row :row_end , aux :aux_end ]
4616- bx2 = x2 [aux :aux_end , col :col_end ]
4617- result [row :row_end , col :col_end ] += np .matmul (bx1 , bx2 )
4625+ intersecting_chunks = get_intersecting_chunks ((), result .shape [:- 2 ], result .chunks [:- 2 ])
4626+ for chunk in intersecting_chunks :
4627+ chunk = chunk .raw
4628+ for row in range (0 , n , p ):
4629+ row_end = builtins .min (row + p , n )
4630+ for col in range (0 , m , q ):
4631+ col_end = builtins .min (col + q , m )
4632+ for aux in range (0 , k , r ):
4633+ aux_end = builtins .min (aux + r , k )
4634+ bx1 = (
4635+ x1 [chunk [- x1 .ndim + 2 :] + (slice (row , row_end ), slice (aux , aux_end ))]
4636+ if x1 .ndim > 2
4637+ else x1 [row :row_end , aux :aux_end ]
4638+ )
4639+ bx2 = (
4640+ x2 [chunk [- x2 .ndim + 2 :] + (slice (aux , aux_end ), slice (col , col_end ))]
4641+ if x2 .ndim > 2
4642+ else x2 [aux :aux_end , col :col_end ]
4643+ )
4644+ result [chunk + (slice (row , row_end ), slice (col , col_end ))] += np .matmul (bx1 , bx2 )
46184645
4619- if x1_is_vector and x2_is_vector :
4620- return result [0 ][0 ]
4646+ if x1_is_vector :
4647+ result .squeeze (axis = - 2 )
4648+ if x2_is_vector :
4649+ result .squeeze (axis = - 1 )
46214650
4622- return result . squeeze ()
4651+ return result
46234652
46244653
46254654def permute_dims (
@@ -5178,6 +5207,16 @@ def _get_local_slice(prior_selection, post_selection, chunk_bounds):
51785207 return locbegin , locend
51795208
51805209
5210+ def get_intersecting_chunks (_slice , shape , chunks ):
5211+ if 0 not in chunks :
5212+ chunk_size = ndindex .ChunkSize (chunks )
5213+ return chunk_size .as_subchunks (_slice , shape ) # if _slice is (), returns all chunks
5214+ else :
5215+ return (
5216+ ndindex .ndindex (...).expand (shape ),
5217+ ) # chunk is whole array so just return full tuple to do loop once
5218+
5219+
51815220def broadcast_to (arr , shape ):
51825221 """
51835222 Broadcast an array to a new shape.
0 commit comments