@@ -37,14 +37,16 @@ def is_constant(x: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]) -> ND
3737def is_constant (x : types .CupyArray , / , * , axis : Literal [0 , 1 ]) -> types .CupyArray : ...
3838@overload
3939def is_constant (x : types .DaskArray , / , * , axis : Literal [0 , 1 ] | None = None ) -> types .DaskArray : ...
40+ @overload
41+ def is_constant [A : types .HasArrayNamespace ](x : A , / , * , axis : Literal [0 , 1 ] | None = None ) -> bool | A : ...
4042
4143
4244def is_constant (
43- x : NDArray [Any ] | types .CSBase | types .CupyArray | types .DaskArray ,
45+ x : NDArray [Any ] | types .CSBase | types .CupyArray | types .DaskArray | types . HasArrayNamespace ,
4446 / ,
4547 * ,
4648 axis : Literal [0 , 1 ] | None = None ,
47- ) -> bool | NDArray [np .bool ] | types .CupyArray | types .DaskArray :
49+ ) -> bool | NDArray [np .bool ] | types .CupyArray | types .DaskArray | types . HasArrayNamespace :
4850 """Check whether values in array are constant.
4951
5052 Parameters
@@ -90,15 +92,17 @@ def mean(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike |
9092def mean (x : GpuArray , / , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None ) -> types .CupyArray : ...
9193@overload
9294def mean (x : types .DaskArray , / , * , axis : Literal [0 , 1 ], dtype : ToDType [Any ] | None = None ) -> types .DaskArray : ...
95+ @overload
96+ def mean [A : types .HasArrayNamespace ](x : A , / , * , axis : Literal [0 , 1 ] | None = None , dtype : DTypeLike | None = None ) -> A : ...
9397
9498
9599def mean (
96- x : CpuArray | GpuArray | DiskArray | types .DaskArray ,
100+ x : CpuArray | GpuArray | DiskArray | types .DaskArray | types . HasArrayNamespace ,
97101 / ,
98102 * ,
99103 axis : Literal [0 , 1 ] | None = None ,
100104 dtype : DTypeLike | None = None ,
101- ) -> NDArray [np .number [Any ]] | types .CupyArray | np .number [Any ] | types .DaskArray :
105+ ) -> NDArray [np .number [Any ]] | types .CupyArray | np .number [Any ] | types .DaskArray | types . HasArrayNamespace :
102106 """Mean over both or one axis.
103107
104108 Parameters
@@ -145,10 +149,10 @@ def mean_var(x: CpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tup
145149def mean_var (x : GpuArray , / , * , axis : Literal [0 , 1 ], correction : int = 0 ) -> tuple [types .CupyArray , types .CupyArray ]: ...
146150@overload
147151def mean_var (x : types .DaskArray , / , * , axis : Literal [0 , 1 ] | None = None , correction : int = 0 ) -> tuple [types .DaskArray , types .DaskArray ]: ...
148-
149-
152+ @ overload
153+ def mean_var [ A : types . HasArrayNamespace ]( x : A , / , * , axis : Literal [ 0 , 1 ] | None = None , correction : int = 0 ) -> tuple [ A , A ]: ...
150154def mean_var (
151- x : CpuArray | GpuArray | types .DaskArray ,
155+ x : CpuArray | GpuArray | types .DaskArray | types . HasArrayNamespace ,
152156 / ,
153157 * ,
154158 axis : Literal [0 , 1 ] | None = None ,
@@ -158,6 +162,7 @@ def mean_var(
158162 | tuple [NDArray [np .float64 ], NDArray [np .float64 ]]
159163 | tuple [types .CupyArray , types .CupyArray ]
160164 | tuple [types .DaskArray , types .DaskArray ]
165+ | tuple [types .HasArrayNamespace , types .HasArrayNamespace ]
161166):
162167 """Mean and variance over both or one axis.
163168
@@ -214,13 +219,13 @@ def _mk_generic_op(op: DtypeOps) -> StatFunDtype: ...
214219# https://github.com/scverse/fast-array-utils/issues/52
215220def _mk_generic_op (op : Ops ) -> StatFunNoDtype | StatFunDtype :
216221 def _generic_op (
217- x : CpuArray | GpuArray | DiskArray | types .DaskArray ,
222+ x : CpuArray | GpuArray | DiskArray | types .DaskArray | types . HasArrayNamespace ,
218223 / ,
219224 * ,
220225 axis : Literal [0 , 1 ] | None = None ,
221226 dtype : DTypeLike | None = None ,
222227 keep_cupy_as_array : bool = False ,
223- ) -> NDArray [Any ] | np .number [Any ] | types .CupyArray | types .DaskArray :
228+ ) -> NDArray [Any ] | np .number [Any ] | types .CupyArray | types .DaskArray | types . HasArrayNamespace :
224229 from ._generic_ops import generic_op
225230
226231 assert dtype is None or op in get_args (DtypeOps ), f"`dtype` is not supported for operation { op !r} "
@@ -249,8 +254,10 @@ def min(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ
249254def min (x : GpuArray , / , * , axis : Literal [0 , 1 ], keep_cupy_as_array : bool = False ) -> types .CupyArray : ...
250255@overload
251256def min (x : types .DaskArray , / , * , axis : Literal [0 , 1 ] | None = None , keep_cupy_as_array : bool = False ) -> types .DaskArray : ...
257+ @overload
258+ def min [A : types .HasArrayNamespace ](x : A , / , * , axis : Literal [0 , 1 ] | None = None , keep_cupy_as_array : bool = False ) -> A : ...
252259def min (
253- x : CpuArray | GpuArray | DiskArray | types .DaskArray ,
260+ x : CpuArray | GpuArray | DiskArray | types .DaskArray | types . HasArrayNamespace ,
254261 / ,
255262 * ,
256263 axis : Literal [0 , 1 ] | None = None ,
@@ -304,8 +311,10 @@ def max(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ
304311def max (x : GpuArray , / , * , axis : Literal [0 , 1 ], keep_cupy_as_array : bool = False ) -> types .CupyArray : ...
305312@overload
306313def max (x : types .DaskArray , / , * , axis : Literal [0 , 1 ] | None = None , keep_cupy_as_array : bool = False ) -> types .DaskArray : ...
314+ @overload
315+ def max [A : types .HasArrayNamespace ](x : A , / , * , axis : Literal [0 , 1 ] | None = None , keep_cupy_as_array : bool = False ) -> A : ...
307316def max (
308- x : CpuArray | GpuArray | DiskArray | types .DaskArray ,
317+ x : CpuArray | GpuArray | DiskArray | types .DaskArray | types . HasArrayNamespace ,
309318 / ,
310319 * ,
311320 axis : Literal [0 , 1 ] | None = None ,
@@ -359,14 +368,16 @@ def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, ke
359368def sum (x : GpuArray , / , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None , keep_cupy_as_array : bool = False ) -> types .CupyArray : ...
360369@overload
361370def sum (x : types .DaskArray , / , * , axis : Literal [0 , 1 ] | None = None , dtype : DTypeLike | None = None , keep_cupy_as_array : bool = False ) -> types .DaskArray : ...
371+ @overload
372+ def sum [A : types .HasArrayNamespace ](x : A , / , * , axis : Literal [0 , 1 ] | None = None , dtype : DTypeLike | None = None , keep_cupy_as_array : bool = False ) -> A : ...
362373def sum (
363- x : CpuArray | GpuArray | DiskArray | types .DaskArray ,
374+ x : CpuArray | GpuArray | DiskArray | types .DaskArray | types . HasArrayNamespace ,
364375 / ,
365376 * ,
366377 axis : Literal [0 , 1 ] | None = None ,
367378 dtype : DTypeLike | None = None ,
368379 keep_cupy_as_array : bool = False ,
369- ) -> NDArray [Any ] | types .CupyArray | np .number [Any ] | types .DaskArray :
380+ ) -> NDArray [Any ] | types .CupyArray | np .number [Any ] | types .DaskArray | types . HasArrayNamespace :
370381 """Sum over both or one axis.
371382
372383 Parameters
0 commit comments