-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path_generic_ops.py
More file actions
111 lines (90 loc) · 3.94 KB
/
_generic_ops.py
File metadata and controls
111 lines (90 loc) · 3.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
from functools import singledispatch
from typing import TYPE_CHECKING, cast, get_args
import numpy as np
from .. import types
from ._typing import DtypeOps
from ._utils import _dask_inner, _dtype_kw
if TYPE_CHECKING:
    # Names used only in annotations and `cast(...)` strings. Because of
    # `from __future__ import annotations`, none of these are needed at runtime.
    from typing import Any, Literal

    from numpy.typing import DTypeLike, NDArray

    from ..typing import CpuArray, DiskArray, GpuArray
    from ._typing import Ops

    # Axis specification: a single axis, a 1-tuple holding one axis, the pair
    # (0, 1), or None.
    # NOTE(review): not referenced in this chunk; presumably used by other
    # stats modules — verify before changing or removing.
    type ComplexAxis = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None
def _run_numpy_op(
    x: CpuArray | GpuArray | DiskArray | types.DaskArray,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
    """Call the numpy reduction named ``op`` (``np.<op>``) on ``x``.

    ``dtype`` is forwarded only in the form produced by ``_dtype_kw(dtype, op)``.
    A CuPy COO matrix result is densified with ``.toarray()``; any other
    result is returned unchanged.
    """
    reduce_fn = getattr(np, op)
    extra_kw = _dtype_kw(dtype, op)
    result = cast(
        "NDArray[Any] | np.number[Any] | types.CupyArray | types.CupyCOOMatrix | types.DaskArray",
        reduce_fn(x, axis=axis, **extra_kw),
    )
    if isinstance(result, types.CupyCOOMatrix):
        # Some CuPy sparse reductions hand back a sparse COO matrix; callers expect an array.
        return result.toarray()
    return result
@singledispatch
def generic_op(
    x: CpuArray | GpuArray | DiskArray | types.DaskArray,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
    """Reduce ``x`` with ``op``; single-dispatch entry point and fallback.

    This body handles every type without a registered overload (e.g. dense
    numpy arrays and zarr / h5py datasets) by delegating to numpy through
    ``_run_numpy_op``.

    ``keep_cupy_as_array`` is only meaningful for CuPy inputs, which have their
    own overload, so it is discarded here.
    """
    del keep_cupy_as_array  # kept in the signature so all overloads share one interface
    if TYPE_CHECKING:
        # The annotation lists these types because `singledispatch` wants them,
        # but they are dispatched to registered overloads and never reach this body.
        assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix)
        # numpy accepts these at runtime even though its type stubs don't say so.
        # TODO: test cupy
        assert not isinstance(x, types.ZarrArray | types.H5Dataset)
    result = _run_numpy_op(x, op, axis=axis, dtype=dtype)
    return cast("NDArray[Any] | np.number[Any]", result)
@generic_op.register(types.CupyArray | types.CupyCSMatrix)
def _generic_op_cupy(
    x: GpuArray,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> types.CupyArray | np.number[Any]:
    """Overload for dense and sparse CuPy inputs.

    For a full reduction (``axis=None``) without ``keep_cupy_as_array`` the
    device result is copied to the host (``.get()``) and unwrapped with ``[()]``.
    In every other case the CuPy result is returned squeezed.
    """
    reduced = cast("types.CupyArray", _run_numpy_op(x, op, axis=axis, dtype=dtype))
    if axis is None and not keep_cupy_as_array:
        return cast("np.number[Any]", reduced.get()[()])
    # squeeze drops singleton dimensions left over by the reduction
    return reduced.squeeze()
@generic_op.register(types.CSBase)
def _generic_op_cs(
    x: types.CSBase,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any]:
    """Overload for scipy compressed sparse (CSR/CSC) matrices and arrays.

    With ``axis=None`` the result is a scalar:

    * ``"max"`` / ``"min"`` also account for the implicit (unstored) zeros, so
      the result matches reducing the densified matrix. An all-implicit matrix
      yields zero, and a zero-size matrix raises ``ValueError`` like numpy.
    * Other ops reduce the stored values directly. For a sum this is exact,
      since unstored zeros add nothing.

    With an axis, ``x`` is converted to a sparse *array* so the reduced
    dimension collapses, and a sparse result is densified and squeezed.
    ``keep_cupy_as_array`` only applies to CuPy inputs and is ignored here.
    """
    del keep_cupy_as_array
    import math

    import scipy.sparse as sp

    # TODO(flying-sheep): once scipy fixes this issue, instead of all this,
    # just convert to sparse array, then `return x.{op}(dtype=dtype)`
    # https://github.com/scipy/scipy/issues/23768

    if axis is None:
        if op not in {"max", "min"}:
            return cast("np.number[Any]", getattr(x.data, op)(**_dtype_kw(dtype, op)))
        if not x.has_canonical_format:
            # Duplicate coordinates are summed when densified, so merge them
            # before comparing stored values (on a copy: don't mutate the input).
            x = x.copy()
            x.sum_duplicates()
        if x.nnz == math.prod(x.shape):
            # Every cell is stored (or x is zero-size, where the empty reduction
            # raises ValueError, like numpy on an empty array).
            return cast("np.number[Any]", getattr(x.data, op)(**_dtype_kw(dtype, op)))
        zero = x.dtype.type(0)
        if x.nnz == 0:
            # Only implicit zeros.
            return cast("np.number[Any]", zero)
        # At least one implicit zero exists and competes with the stored extremum;
        # np.maximum/np.minimum propagate NaN like numpy's dense reductions.
        combine = np.maximum if op == "max" else np.minimum
        stored = getattr(x.data, op)(**_dtype_kw(dtype, op))
        return cast("np.number[Any]", combine(stored, zero))

    if TYPE_CHECKING:  # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true
        assert isinstance(dtype, np.dtype | type | None)
    # convert to array so dimensions collapse as expected
    x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op))  # type: ignore[arg-type]
    rv = cast("NDArray[Any] | types.coo_array | np.number[Any]", getattr(x, op)(axis=axis))
    # old scipy versions’ sparray.{max,min}() return a 1×n/n×1 sparray here, so we squeeze
    return rv.toarray().squeeze() if isinstance(rv, types.coo_array) else rv
@generic_op.register(types.DaskArray)
def _generic_op_dask(
    x: types.DaskArray,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> types.DaskArray:
    """Overload for dask arrays; the work is delegated to ``_dask_inner``.

    For ops that accept a dtype, a missing ``dtype`` is replaced by the dtype
    numpy itself would produce. For example, summing a bool array gives an
    integer result (``NDArray[bool].sum().dtype == int64``), so the dask result
    dtype matches numpy's. ``keep_cupy_as_array`` is forwarded unchanged.
    """
    # NOTE(review): `get_args` only yields the members if `DtypeOps` is a plain
    # `Literal[...]`; a PEP 695 `type` alias would make this always empty — verify.
    if dtype is None and op in get_args(DtypeOps):
        probe = np.zeros(1, dtype=x.dtype)
        dtype = getattr(np, op)(probe).dtype
    return _dask_inner(
        x,
        op,
        axis=axis,
        dtype=dtype,
        keep_cupy_as_array=keep_cupy_as_array,
    )