Skip to content

Commit 49283b0

Browse files
committed
fix cupy tests
1 parent c890bc3 commit 49283b0

2 files changed

Lines changed: 5 additions & 12 deletions

File tree

src/fast_array_utils/stats/_generic_ops.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,7 @@ def generic_op(
3232
dtype: DTypeLike | None = None,
3333
keep_cupy_as_array: bool = False,
3434
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
35-
del keep_cupy_as_array
36-
if TYPE_CHECKING:
37-
# these are never passed to this fallback function, but `singledispatch` wants them
38-
assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix)
39-
# np supports these, but doesn’t know it. (TODO: test cupy)
40-
assert not isinstance(x, types.ZarrArray | types.H5Dataset)
41-
42-
arr = getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))
43-
return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr
35+
raise NotImplementedError
4436

4537

4638
@generic_op.register(np.ndarray)
@@ -87,7 +79,8 @@ def _generic_op_cupy(
8779
dtype: DTypeLike | None = None,
8880
keep_cupy_as_array: bool = False,
8981
) -> types.CupyArray | np.number[Any]:
90-
arr = cast("types.CupyArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
82+
arr = cast("types.CupyArray | types.CupyCOOMatrix", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
83+
arr = arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr
9184
return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze()
9285

9386

src/fast_array_utils/stats/_power.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
2929
raise NotImplementedError
3030

3131

32-
@_power.register(np.ndarray)
33-
def _power_numpy(x: np.ndarray, n: int, /, dtype: DTypeLike | None = None) -> np.ndarray:
32+
@_power.register(np.ndarray | types.CupyArray)
33+
def _power_numpy_cupy(x: np.ndarray, n: int, /, dtype: DTypeLike | None = None) -> np.ndarray:
3434
# avoids slower xp.pow(xp.astype(...)) path
3535
return x**n if dtype is None else np.power(x, n, dtype=dtype)
3636

0 commit comments

Comments
 (0)