forked from scverse/fast-array-utils
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_generic_ops.py
More file actions
137 lines (111 loc) · 4.56 KB
/
_generic_ops.py
File metadata and controls
137 lines (111 loc) · 4.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
from functools import singledispatch
from typing import TYPE_CHECKING, cast, get_args
import numpy as np
from .. import types
from ._typing import DtypeOps
from ._utils import _dask_inner, _dtype_kw
if TYPE_CHECKING:
from typing import Any, Literal
from numpy.typing import DTypeLike, NDArray
from ..typing import CpuArray, DiskArray, GpuArray
from ._typing import Ops
# Axis specification accepted as a type: a bare axis ``0`` or ``1``, a 1-tuple holding
# one of those, the pair ``(0, 1)``, or ``None``.
type ComplexAxis = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None
@singledispatch
def generic_op(
    x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
    """Apply the NumPy function named ``op`` to ``x`` (single-dispatch fallback).

    This implementation runs for every input type that has no dedicated
    registration.  It looks up ``op`` on the :mod:`numpy` module and calls it
    with ``axis`` plus whatever keyword arguments ``_dtype_kw(dtype, op)``
    produces.  ``keep_cupy_as_array`` is accepted so that all implementations
    share one signature, but it is ignored here.

    A ``types.CupyCOOMatrix`` result is converted with ``.toarray()``; any
    other result is returned unchanged.
    """
    del keep_cupy_as_array
    if TYPE_CHECKING:
        # Types with their own registration never arrive here;
        # the asserts only exist to narrow `x` for the type checker.
        assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix)
        # NumPy handles these at runtime, but its annotations don't say so. (TODO: test cupy)
        assert not isinstance(x, types.ZarrArray | types.H5Dataset)
    np_func = getattr(np, op)
    result = np_func(x, axis=axis, **_dtype_kw(dtype, op))
    if isinstance(result, types.CupyCOOMatrix):
        return result.toarray()
    return result
@generic_op.register(np.ndarray)
# register explicitly to avoid the array API path and performance slow down
def _generic_op_numpy(
    x: np.ndarray,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any]:
    """Apply the NumPy function named ``op`` directly to a NumPy array.

    ``op`` is looked up on the :mod:`numpy` module and called with ``axis``
    plus the keyword arguments produced by ``_dtype_kw(dtype, op)``.
    ``keep_cupy_as_array`` only exists for signature parity with the other
    implementations and is ignored.

    The result is returned as-is: a NumPy operation on an ``np.ndarray`` never
    yields a CuPy sparse matrix, so no conversion step is needed here.
    """
    del keep_cupy_as_array
    return cast(
        "NDArray[Any] | np.number[Any]",
        getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)),
    )
@generic_op.register(types.HasArrayNamespace)
def _generic_op_array_api(
    x: types.HasArrayNamespace,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> Any:  # noqa: ANN401
    """Run ``op`` through the array's own array-API namespace.

    The namespace is resolved with ``array_api_compat.array_namespace(x)``;
    the function named ``op`` is then looked up on it and called with ``axis``
    plus the keyword arguments produced by ``_dtype_kw(dtype, op)``.
    ``keep_cupy_as_array`` is ignored.
    """
    del keep_cupy_as_array
    # Imported lazily, inside the function rather than at module level.
    from array_api_compat import array_namespace

    namespace = array_namespace(x)
    namespace_func = getattr(namespace, op)
    return namespace_func(x, axis=axis, **_dtype_kw(dtype, op))
@generic_op.register(types.CupyArray | types.CupyCSMatrix)
def _generic_op_cupy(
    x: GpuArray,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> types.CupyArray | np.number[Any]:
    """Apply the NumPy function named ``op`` to a CuPy array or CuPy CS matrix.

    ``op`` is looked up on the :mod:`numpy` module and called on ``x`` with
    ``axis`` plus the keyword arguments produced by ``_dtype_kw(dtype, op)``.

    When ``axis`` is ``None`` and ``keep_cupy_as_array`` is false, the result
    is fetched with ``.get()`` and unwrapped to a scalar via ``[()]``.
    In every other case the squeezed array result is returned.
    """
    np_func = getattr(np, op)
    reduced = cast("types.CupyArray", np_func(x, axis=axis, **_dtype_kw(dtype, op)))
    if keep_cupy_as_array or axis is not None:
        return reduced.squeeze()
    # full reduction and the caller did not ask to keep the array form
    return cast("np.number[Any]", reduced.get()[()])
@generic_op.register(types.CSBase)
def _generic_op_cs(
    x: types.CSBase,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any]:
    """Apply the method named ``op`` to a compressed sparse container.

    With ``axis=None`` the method ``op`` is called on ``x.data`` with the
    keyword arguments produced by ``_dtype_kw(dtype, op)``.  Otherwise ``x`` is
    first rebuilt as a ``csr_array`` or ``csc_array`` (chosen by ``x.format``,
    passing the same ``_dtype_kw`` keywords) and ``op`` is called on that with
    ``axis``.  ``keep_cupy_as_array`` is ignored.
    """
    del keep_cupy_as_array
    import scipy.sparse as sp

    # TODO(flying-sheep): once scipy fixes this issue, instead of all this,
    # just convert to sparse array, then `return x.{op}(dtype=dtype)`
    # https://github.com/scipy/scipy/issues/23768

    if axis is None:
        data_method = getattr(x.data, op)
        return cast("np.number[Any]", data_method(**_dtype_kw(dtype, op)))

    if TYPE_CHECKING:  # scipy-stubs rejects e.g. "int64" as a dtype although it is valid
        assert isinstance(dtype, np.dtype | type | None)

    # Rebuild as a sparse *array* so dimensions collapse as expected.
    array_cls = sp.csr_array if x.format == "csr" else sp.csc_array
    as_sparray = array_cls(x, **_dtype_kw(dtype, op))  # type: ignore[arg-type,call-overload]
    reduced = cast(
        "NDArray[Any] | types.coo_array | np.number[Any]",
        getattr(as_sparray, op)(axis=axis),
    )
    # Old scipy versions' sparray.{max,min}() return a 1×n/n×1 sparray here, so squeeze it.
    if isinstance(reduced, types.coo_array):
        return reduced.toarray().squeeze()
    return reduced
@generic_op.register(types.DaskArray)
def _generic_op_dask(
    x: types.DaskArray,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> types.DaskArray:
    """Delegate ``op`` on a Dask array to ``_dask_inner``.

    If no ``dtype`` was given and ``op`` is one of the ``DtypeOps``, the dtype
    is first fixed to whatever NumPy produces when running ``op`` on a
    one-element zero array of ``x.dtype``.  All arguments are then forwarded to
    ``_dask_inner``.
    """
    if dtype is None and op in get_args(DtypeOps):
        # Explicitly use numpy's result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
        np_func = getattr(np, op)
        dtype = np_func(np.zeros(1, dtype=x.dtype)).dtype
    return _dask_inner(
        x,
        op,
        axis=axis,
        dtype=dtype,
        keep_cupy_as_array=keep_cupy_as_array,
    )