forked from scverse/fast-array-utils
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_generic_ops.py
More file actions
127 lines (102 loc) · 4.35 KB
/
_generic_ops.py
File metadata and controls
127 lines (102 loc) · 4.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
from functools import singledispatch
from typing import TYPE_CHECKING, cast, get_args
import numpy as np
from .. import types
from ._typing import DtypeOps
from ._utils import _dask_inner, _dtype_kw
if TYPE_CHECKING:
from typing import Any, Literal
from numpy.typing import DTypeLike, NDArray
from ..typing import CpuArray, DiskArray, GpuArray
from ._typing import Ops
type ComplexAxis = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None
@singledispatch
def generic_op(
    x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> Any:  # Fallback handles arbitrary array-api-compatible types, so return type can't be narrowed  # noqa: ANN401
    """Apply the reduction named ``op`` to ``x`` (``singledispatch`` fallback).

    Objects that ``array_api_compat`` recognises as array API objects are
    reduced with the namespace it resolves for them; anything else is handed
    to the NumPy function of the same name.

    Parameters
    ----------
    x
        The array-like object to reduce.
    op
        Name of the function looked up on the chosen namespace.
    axis
        Forwarded unchanged as the ``axis`` keyword.
    dtype
        Forwarded through ``_dtype_kw(dtype, op)``, which decides whether a
        ``dtype`` keyword is passed at all.
    keep_cupy_as_array
        Ignored by this fallback; accepted so all overloads share a signature.

    Returns
    -------
    Whatever the resolved function returns. On the NumPy path only, a
    ``types.CupyCOOMatrix`` result is converted with ``.toarray()``.
    """
    del keep_cupy_as_array
    if TYPE_CHECKING:
        # `singledispatch` routes these to their own overloads, never to this fallback
        assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix)
        # np supports these, but its type information does not say so. (TODO: test cupy)
        assert not isinstance(x, types.ZarrArray | types.H5Dataset)

    # Catch array-api-compat-wrapped types that lack __array_namespace__ (i.e. PyTorch)
    import array_api_compat

    uses_array_api = array_api_compat.is_array_api_obj(x)
    namespace = array_api_compat.array_namespace(x) if uses_array_api else np
    result = getattr(namespace, op)(x, axis=axis, **_dtype_kw(dtype, op))

    # Array API results are returned untouched; only the NumPy path may need densifying.
    if uses_array_api or not isinstance(result, types.CupyCOOMatrix):
        return result
    return result.toarray()
@generic_op.register(types.HasArrayNamespace)
def _generic_op_array_api(
    x: types.HasArrayNamespace,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> Any:  # noqa: ANN401
    """Reduce an array with native array API support.

    The function named ``op`` is looked up on the namespace that
    ``array_api_compat.array_namespace`` resolves for ``x`` and called with
    ``axis`` plus whatever ``_dtype_kw(dtype, op)`` contributes.
    ``keep_cupy_as_array`` is accepted for signature parity and ignored.
    """
    del keep_cupy_as_array
    import array_api_compat

    namespace = array_api_compat.array_namespace(x)
    reduction = getattr(namespace, op)
    return reduction(x, axis=axis, **_dtype_kw(dtype, op))
@generic_op.register(types.CupyArray | types.CupyCSMatrix)
def _generic_op_cupy(
    x: GpuArray,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> types.CupyArray | np.number[Any]:
    """Reduce a CuPy array or CuPy sparse matrix with the operation ``op``.

    The NumPy-level function named ``op`` is called on ``x`` with ``axis`` and
    the keywords produced by ``_dtype_kw(dtype, op)``.

    Returns
    -------
    For a full reduction (``axis is None``) with ``keep_cupy_as_array`` unset,
    the result is fetched with ``.get()`` and unwrapped to a scalar via
    ``[()]``. In every other case the squeezed array result is returned.
    """
    reduced = cast("types.CupyArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
    if keep_cupy_as_array or axis is not None:
        return reduced.squeeze()
    return cast("np.number[Any]", reduced.get()[()])
@generic_op.register(types.CSBase)
def _generic_op_cs(
    x: types.CSBase,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any]:
    """Reduce a compressed sparse (CSR/CSC) container with the operation ``op``.

    Parameters
    ----------
    x
        The sparse container to reduce.
    op
        Name of the reduction; looked up as a method on ``x.data`` or on the
        sparse container itself.
    axis
        ``None`` for a full reduction, otherwise the axis to reduce along.
    dtype
        Forwarded through ``_dtype_kw(dtype, op)``, which decides whether a
        ``dtype`` keyword is passed at all.
    keep_cupy_as_array
        Ignored; accepted so all overloads share a signature.

    Returns
    -------
    A scalar for a full reduction, otherwise the per-axis result
    (densified and squeezed if the sparse method returned a ``coo_array``).
    """
    del keep_cupy_as_array
    import scipy.sparse as sp

    # TODO(flying-sheep): once scipy fixes this issue, instead of all this,
    # just convert to sparse array, then `return x.{op}(dtype=dtype)`
    # https://github.com/scipy/scipy/issues/23768
    if axis is None:
        if op in get_args(DtypeOps):
            # Reducing only the stored entries is the existing behaviour for dtype-accepting ops.
            return cast("np.number[Any]", getattr(x.data, op)(**_dtype_kw(dtype, op)))
        # Other reductions (max/min) must take implicit zeros into account, which
        # `x.data` alone cannot do: it would report e.g. a negative maximum for a
        # matrix with only negative stored values, and fail when nothing is stored.
        # The sparse container's own method handles both, matching the per-axis path below.
        return cast("np.number[Any]", getattr(x, op)(axis=None))

    if TYPE_CHECKING:  # scipy-stubs thinks e.g. "int64" is invalid, which isn't true
        assert isinstance(dtype, np.dtype | type | None)
    # convert to array so dimensions collapse as expected
    x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op))  # type: ignore[arg-type]
    rv = cast("NDArray[Any] | types.coo_array | np.number[Any]", getattr(x, op)(axis=axis))
    # old scipy versions' sparray.{max,min}() return a 1×n/n×1 sparray here, so we squeeze
    return rv.toarray().squeeze() if isinstance(rv, types.coo_array) else rv
@generic_op.register(types.DaskArray)
def _generic_op_dask(
    x: types.DaskArray,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> types.DaskArray:
    """Reduce a Dask array by delegating to ``_dask_inner``.

    When ``op`` is one of ``DtypeOps`` and no ``dtype`` was given, the result
    dtype is pinned to what NumPy itself produces for that operation on
    ``x.dtype``. All arguments, including ``keep_cupy_as_array``, are then
    passed on to ``_dask_inner``.
    """
    if dtype is None and op in get_args(DtypeOps):
        # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
        probe = np.zeros(1, dtype=x.dtype)
        dtype = getattr(np, op)(probe).dtype
    return _dask_inner(x, op, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array)