-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path_power.py
More file actions
43 lines (28 loc) · 1.63 KB
/
_power.py
File metadata and controls
43 lines (28 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations
from functools import singledispatch
from typing import TYPE_CHECKING
import numpy as np
from .. import types
if TYPE_CHECKING:
from numpy.typing import DTypeLike
from fast_array_utils.typing import CpuArray, GpuArray
# Union of all supported array types, excluding disk-backed ones and CSDataset.
# Used as the upper bound of `power`'s type parameter and as the argument and
# return annotation of the `_power` dispatcher.
type Array = CpuArray | GpuArray | types.DaskArray
def power[Arr: Array](x: Arr, n: int, /, dtype: DTypeLike | None = None) -> Arr:
"""Take array or matrix to a power."""
# This wrapper is necessary because TypeVars can’t be used in `singledispatch` functions
return _power(x, n, dtype=dtype) # type: ignore[return-value]
@singledispatch
def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
    """Raise ``x`` element-wise to the power ``n`` (default dispatch target).

    Handles every type without a more specific registration; compressed sparse
    matrices and dask arrays are registered separately.

    Without ``dtype`` the ``**`` operator is used, so the result dtype is left
    to ``x``'s own power implementation. With ``dtype`` the call goes through
    :func:`numpy.power`, so the output dtype can be requested explicitly.
    """
    if TYPE_CHECKING:
        # Static-only narrowing: the registered overloads cover these types, so
        # type checkers may assume they never reach this fallback.
        assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)
    if dtype is not None:
        return np.power(x, n, dtype=dtype)
    return x**n  # type: ignore[operator]
@_power.register(types.CSBase | types.CupyCSMatrix)
def _power_cs[Mat: types.CSBase | types.CupyCSMatrix](x: Mat, n: int, /, dtype: DTypeLike | None = None) -> Mat:
new_data = power(x.data, n, dtype=dtype)
return type(x)((new_data, x.indices, x.indptr), shape=x.shape, dtype=new_data.dtype) # type: ignore[call-overload,return-value]
@_power.register(types.DaskArray)
def _power_dask(x: types.DaskArray, n: int, /, dtype: DTypeLike | None = None) -> types.DaskArray:
    """Raise every block of a dask array to the power ``n``.

    Each block is passed to :func:`power`, so the block's own type (dense or
    sparse, CPU or GPU) is dispatched again for every chunk.

    Args:
        x: Dask array to exponentiate.
        n: Exponent.
        dtype: Requested result dtype; ``None`` keeps ``x.dtype`` for the meta.

    Returns:
        A dask array whose blocks are ``power(block, n, dtype=dtype)``.
    """
    # Compare against ``None`` explicitly instead of ``dtype or x.dtype``:
    # ``np.dtype`` instances implement ``__len__`` (number of fields), so any
    # non-structured dtype such as ``np.dtype("float32")`` is falsy and the
    # requested dtype would silently be replaced by ``x.dtype`` in the meta.
    meta_dtype = x.dtype if dtype is None else dtype
    # Passing ``meta`` tells dask the output chunk type and dtype up front.
    meta = x._meta.astype(meta_dtype)  # noqa: SLF001
    return x.map_blocks(lambda c: power(c, n, dtype=dtype), dtype=dtype, meta=meta)  # type: ignore[type-var,arg-type]