@@ -32,15 +32,7 @@ def generic_op(
3232 dtype : DTypeLike | None = None ,
3333 keep_cupy_as_array : bool = False ,
3434) -> NDArray [Any ] | np .number [Any ] | types .CupyArray | types .DaskArray :
35- del keep_cupy_as_array
36- if TYPE_CHECKING :
37- # these are never passed to this fallback function, but `singledispatch` wants them
38- assert not isinstance (x , types .CSBase | types .DaskArray | types .CupyArray | types .CupyCSMatrix )
39- # np supports these, but doesn’t know it. (TODO: test cupy)
40- assert not isinstance (x , types .ZarrArray | types .H5Dataset )
41-
42- arr = getattr (np , op )(x , axis = axis , ** _dtype_kw (dtype , op ))
43- return arr .toarray () if isinstance (arr , types .CupyCOOMatrix ) else arr
35+ raise NotImplementedError
4436
4537
4638@generic_op .register (np .ndarray )
@@ -87,7 +79,8 @@ def _generic_op_cupy(
8779 dtype : DTypeLike | None = None ,
8880 keep_cupy_as_array : bool = False ,
8981) -> types .CupyArray | np .number [Any ]:
90- arr = cast ("types.CupyArray" , getattr (np , op )(x , axis = axis , ** _dtype_kw (dtype , op )))
82+ arr = cast ("types.CupyArray | types.CupyCOOMatrix" , getattr (np , op )(x , axis = axis , ** _dtype_kw (dtype , op )))
83+ arr = arr .toarray () if isinstance (arr , types .CupyCOOMatrix ) else arr
9184 return cast ("np.number[Any]" , arr .get ()[()]) if not keep_cupy_as_array and axis is None else arr .squeeze ()
9285
9386
0 commit comments