Skip to content

Commit c77a1bc

Browse files
committed
types for others
1 parent 3d4ee3a commit c77a1bc

2 files changed

Lines changed: 24 additions & 11 deletions

File tree

src/fast_array_utils/conv/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,19 @@ def to_dense(x: types.DaskArray, /, *, order: Literal["K", "A", "C", "F"] = "K",
3636
def to_dense(x: GpuArray | types.CupySpMatrix, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[False] = False) -> types.CupyArray: ...
3737
@overload
3838
def to_dense(x: GpuArray | types.CupySpMatrix, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[True]) -> NDArray[Any]: ...
39+
@overload
40+
def to_dense[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[False] = False) -> A: ...
41+
@overload
42+
def to_dense[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[True]) -> NDArray[Any]: ...
3943

4044

4145
def to_dense(
42-
x: CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix,
46+
x: CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix | types.HasArrayNamespace,
4347
/,
4448
*,
4549
order: Literal["K", "A", "C", "F"] = "K",
4650
to_cpu_memory: bool = False,
47-
) -> NDArray[Any] | types.DaskArray | types.CupyArray:
51+
) -> NDArray[Any] | types.DaskArray | types.CupyArray | types.HasArrayNamespace:
4852
r"""Convert x to a dense array.
4953
5054
If ``to_cpu_memory`` is :data:`False`, :class:`dask.array.Array`\ s and

src/fast_array_utils/stats/__init__.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,16 @@ def is_constant(x: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]) -> ND
3737
def is_constant(x: types.CupyArray, /, *, axis: Literal[0, 1]) -> types.CupyArray: ...
3838
@overload
3939
def is_constant(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None) -> types.DaskArray: ...
40+
@overload
41+
def is_constant[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None) -> bool | A: ...
4042

4143

4244
def is_constant(
43-
x: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
45+
x: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray | types.HasArrayNamespace,
4446
/,
4547
*,
4648
axis: Literal[0, 1] | None = None,
47-
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray:
49+
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray | types.HasArrayNamespace:
4850
"""Check whether values in array are constant.
4951
5052
Parameters
@@ -90,15 +92,17 @@ def mean(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike |
9092
def mean(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> types.CupyArray: ...
9193
@overload
9294
def mean(x: types.DaskArray, /, *, axis: Literal[0, 1], dtype: ToDType[Any] | None = None) -> types.DaskArray: ...
95+
@overload
96+
def mean[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None) -> A: ...
9397

9498

9599
def mean(
96-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
100+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
97101
/,
98102
*,
99103
axis: Literal[0, 1] | None = None,
100104
dtype: DTypeLike | None = None,
101-
) -> NDArray[np.number[Any]] | types.CupyArray | np.number[Any] | types.DaskArray:
105+
) -> NDArray[np.number[Any]] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace:
102106
"""Mean over both or one axis.
103107
104108
Parameters
@@ -145,10 +149,10 @@ def mean_var(x: CpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tup
145149
def mean_var(x: GpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tuple[types.CupyArray, types.CupyArray]: ...
146150
@overload
147151
def mean_var(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[types.DaskArray, types.DaskArray]: ...
148-
149-
152+
@overload
153+
def mean_var[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[A, A]: ...
150154
def mean_var(
151-
x: CpuArray | GpuArray | types.DaskArray,
155+
x: CpuArray | GpuArray | types.DaskArray | types.HasArrayNamespace,
152156
/,
153157
*,
154158
axis: Literal[0, 1] | None = None,
@@ -158,6 +162,7 @@ def mean_var(
158162
| tuple[NDArray[np.float64], NDArray[np.float64]]
159163
| tuple[types.CupyArray, types.CupyArray]
160164
| tuple[types.DaskArray, types.DaskArray]
165+
| tuple[types.HasArrayNamespace, types.HasArrayNamespace]
161166
):
162167
"""Mean and variance over both or one axis.
163168
@@ -249,8 +254,10 @@ def min(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ
249254
def min(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ...
250255
@overload
251256
def min(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ...
257+
@overload
258+
def min[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> A: ...
252259
def min(
253-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
260+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
254261
/,
255262
*,
256263
axis: Literal[0, 1] | None = None,
@@ -304,8 +311,10 @@ def max(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ
304311
def max(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ...
305312
@overload
306313
def max(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ...
314+
@overload
315+
def max[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> A: ...
307316
def max(
308-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
317+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
309318
/,
310319
*,
311320
axis: Literal[0, 1] | None = None,

0 commit comments

Comments
 (0)