@@ -4622,13 +4622,15 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
46224622 return result .squeeze ()
46234623
46244624
4625- def permute_dims (arr : NDArray , axes : tuple [int ] | list [int ] | None = None , ** kwargs : Any ) -> NDArray :
4625+ def permute_dims (
4626+ arr : NDArray | np .ndarray , axes : tuple [int ] | list [int ] | None = None , ** kwargs : Any
4627+ ) -> NDArray :
46264628 """
46274629 Permutes the axes (dimensions) of an array.
46284630
46294631 Parameters
46304632 ----------
4631- arr: :ref:`NDArray`
4633+ arr: :ref:`NDArray` | np.ndarray
46324634 The input array.
46334635 axes: tuple[int], list[int], optional
46344636 The desired permutation of axes. If None, the axes are reversed by default.
@@ -4694,6 +4696,8 @@ def permute_dims(arr: NDArray, axes: tuple[int] | list[int] | None = None, **kwa
46944696 """
46954697 if np .isscalar (arr ) or arr .ndim < 2 :
46964698 return arr
4699+ if isinstance (arr , np .ndarray ): # for array-api test compliance (does getitem for comparison)
4700+ return np .permute_dims (arr , axes )
46974701
46984702 ndim = arr .ndim
46994703
@@ -4768,13 +4772,13 @@ def transpose(x, **kwargs: Any) -> NDArray:
47684772 return permute_dims (x , ** kwargs )
47694773
47704774
4771- def matrix_transpose (arr : NDArray , ** kwargs : Any ) -> NDArray :
4775+ def matrix_transpose (arr : NDArray | np . ndarray , ** kwargs : Any ) -> NDArray :
47724776 """
47734777 Transposes a matrix (or a stack of matrices).
47744778
47754779 Parameters
47764780 ----------
4777- arr: :ref:`NDArray`
4781+ arr: :ref:`NDArray` | np.ndarray
47784782 The input NDArray having shape ``(..., M, N)`` and whose innermost two dimensions form
47794783 ``MxN`` matrices.
47804784
0 commit comments