Skip to content

Commit 607ba6d

Browse files
committed
Fix matrix_transpose/permute_dims to accept np.ndarray
1 parent 15310df commit 607ba6d

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

src/blosc2/ndarray.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4622,13 +4622,15 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
46224622
return result.squeeze()
46234623

46244624

4625-
def permute_dims(arr: NDArray, axes: tuple[int] | list[int] | None = None, **kwargs: Any) -> NDArray:
4625+
def permute_dims(
4626+
arr: NDArray | np.ndarray, axes: tuple[int] | list[int] | None = None, **kwargs: Any
4627+
) -> NDArray:
46264628
"""
46274629
Permutes the axes (dimensions) of an array.
46284630
46294631
Parameters
46304632
----------
4631-
arr: :ref:`NDArray`
4633+
arr: :ref:`NDArray` | np.ndarray
46324634
The input array.
46334635
axes: tuple[int], list[int], optional
46344636
The desired permutation of axes. If None, the axes are reversed by default.
@@ -4694,6 +4696,8 @@ def permute_dims(arr: NDArray, axes: tuple[int] | list[int] | None = None, **kwa
46944696
"""
46954697
if np.isscalar(arr) or arr.ndim < 2:
46964698
return arr
4699+
if isinstance(arr, np.ndarray): # for array-api test compliance (does getitem for comparison)
4700+
return np.permute_dims(arr, axes)
46974701

46984702
ndim = arr.ndim
46994703

@@ -4768,13 +4772,13 @@ def transpose(x, **kwargs: Any) -> NDArray:
47684772
return permute_dims(x, **kwargs)
47694773

47704774

4771-
def matrix_transpose(arr: NDArray, **kwargs: Any) -> NDArray:
4775+
def matrix_transpose(arr: NDArray | np.ndarray, **kwargs: Any) -> NDArray:
47724776
"""
47734777
Transposes a matrix (or a stack of matrices).
47744778
47754779
Parameters
47764780
----------
4777-
arr: :ref:`NDArray`
4781+
arr: :ref:`NDArray` | np.ndarray
47784782
The input NDArray having shape ``(..., M, N)`` and whose innermost two dimensions form
47794783
``MxN`` matrices.
47804784

0 commit comments

Comments
 (0)