-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path_sum.py
More file actions
121 lines (95 loc) · 3.43 KB
/
_sum.py
File metadata and controls
121 lines (95 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, overload
import numpy as np
from .._import import lazy_singledispatch
if TYPE_CHECKING:
from typing import Any, Literal
from numpy.typing import ArrayLike, DTypeLike, NDArray
from .. import types
@overload
def sum(
    x: ArrayLike,
    /,
    *,
    axis: None = None,
    dtype: DTypeLike | None = None,
) -> np.number[Any]: ...


@overload
def sum(
    x: ArrayLike,
    /,
    *,
    axis: Literal[0, 1],
    dtype: DTypeLike | None = None,
) -> NDArray[Any]: ...


@overload
def sum(
    x: types.DaskArray,
    /,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
) -> types.DaskArray: ...


def sum(
    x: ArrayLike | types.DaskArray,
    /,
    *,
    axis: Literal[0, 1, None] = None,
    dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
    """Sum the elements of ``x``, either all of them or along a single axis.

    Parameters
    ----------
    x
        Array to reduce.
    axis
        :data:`None` (the default) to sum over every element,
        or ``0`` / ``1`` to sum along that axis only.
    dtype
        Forwarded unchanged to the implementation selected for ``x``.

    Returns
    -------
    When ``axis`` is :data:`None`, a scalar holding the total of all elements.
    Otherwise a 1D array holding the totals along the requested axis.
    Per the signature overloads, a dask array input yields a dask array.

    See Also
    --------
    :func:`numpy.sum`
    """
    # All real work happens in `_sum`, which has per-type implementations registered.
    return _sum(x, axis=axis, dtype=dtype)
@lazy_singledispatch
def _sum(
    x: ArrayLike,
    /,
    *,
    axis: Literal[0, 1, None] = None,
    dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.DaskArray:
    """Base implementation: hand ``x`` straight to :func:`numpy.sum`.

    Type-specific implementations are attached to this function through ``_sum.register``.
    """
    total = np.sum(x, axis=axis, dtype=dtype)
    return total  # type: ignore[no-any-return]
@_sum.register("fast_array_utils.types:CSBase", "scipy.sparse")
def _(
    x: types.CSBase,
    /,
    *,
    axis: Literal[0, 1, None] = None,
    dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any]:
    """Sum a compressed sparse (CSR/CSC) matrix or array.

    Inputs that are ``CSMatrix`` instances are first re-wrapped as the sparse *array*
    class of the same format; the sum itself is then delegated to :func:`numpy.sum`.
    """
    import scipy.sparse as sp

    from ..types import CSMatrix

    if isinstance(x, CSMatrix):
        # Keep the storage format: "csr" stays CSR, anything else becomes CSC.
        as_sparse_array = sp.csr_array if x.format == "csr" else sp.csc_array
        x = as_sparse_array(x)

    total = np.sum(x, axis=axis, dtype=dtype)
    return total  # type: ignore[no-any-return]
@_sum.register("dask.array:Array")
def _(
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> types.DaskArray:
if TYPE_CHECKING:
from dask.array.reductions import reduction
else:
from dask.array import reduction
if isinstance(x._meta, np.matrix): # pragma: no cover # noqa: SLF001
msg = "sum does not support numpy matrices"
raise TypeError(msg)
def sum_drop_keepdims(
a: NDArray[Any] | types.CSBase,
/,
*,
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
keepdims: bool = False,
) -> NDArray[Any]:
del keepdims
match axis:
case (0 | 1 as n,):
axis = n
case (0, 1) | (1, 0):
axis = None
case tuple(): # pragma: no cover
msg = f"`sum` can only sum over `axis=0|1|(0,1)` but got {axis} instead"
raise ValueError(msg)
rv = sum(a, axis=axis, dtype=dtype)
rv = np.array(rv, ndmin=1) # make sure rv is at least 1D
return rv.reshape((1, len(rv)))
if dtype is None:
# Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`)
dtype = np.zeros(1, dtype=x.dtype).sum().dtype
return reduction( # type: ignore[no-any-return,no-untyped-call]
x,
sum_drop_keepdims,
partial(np.sum, dtype=dtype),
axis=axis,
dtype=dtype,
meta=np.array([], dtype=dtype),
)