@@ -37,14 +37,16 @@ def is_constant(x: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]) -> ND
3737def is_constant (x : types .CupyArray , / , * , axis : Literal [0 , 1 ]) -> types .CupyArray : ...
3838@overload
3939def is_constant (x : types .DaskArray , / , * , axis : Literal [0 , 1 ] | None = None ) -> types .DaskArray : ...
40+ @overload
41+ def is_constant [A : types .HasArrayNamespace ](x : A , / , * , axis : Literal [0 , 1 ] | None = None ) -> bool | A : ...
4042
4143
4244def is_constant (
43- x : NDArray [Any ] | types .CSBase | types .CupyArray | types .DaskArray ,
45+ x : NDArray [Any ] | types .CSBase | types .CupyArray | types .DaskArray | types . HasArrayNamespace ,
4446 / ,
4547 * ,
4648 axis : Literal [0 , 1 ] | None = None ,
47- ) -> bool | NDArray [np .bool ] | types .CupyArray | types .DaskArray :
49+ ) -> bool | NDArray [np .bool ] | types .CupyArray | types .DaskArray | types . HasArrayNamespace :
4850 """Check whether values in array are constant.
4951
5052 Parameters
@@ -90,15 +92,17 @@ def mean(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike |
9092def mean (x : GpuArray , / , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None ) -> types .CupyArray : ...
9193@overload
9294def mean (x : types .DaskArray , / , * , axis : Literal [0 , 1 ], dtype : ToDType [Any ] | None = None ) -> types .DaskArray : ...
95+ @overload
96+ def mean [A : types .HasArrayNamespace ](x : A , / , * , axis : Literal [0 , 1 ] | None = None , dtype : DTypeLike | None = None ) -> A : ...
9397
9498
9599def mean (
96- x : CpuArray | GpuArray | DiskArray | types .DaskArray ,
100+ x : CpuArray | GpuArray | DiskArray | types .DaskArray | types . HasArrayNamespace ,
97101 / ,
98102 * ,
99103 axis : Literal [0 , 1 ] | None = None ,
100104 dtype : DTypeLike | None = None ,
101- ) -> NDArray [np .number [Any ]] | types .CupyArray | np .number [Any ] | types .DaskArray :
105+ ) -> NDArray [np .number [Any ]] | types .CupyArray | np .number [Any ] | types .DaskArray | types . HasArrayNamespace :
102106 """Mean over both or one axis.
103107
104108 Parameters
@@ -145,10 +149,10 @@ def mean_var(x: CpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tup
145149def mean_var (x : GpuArray , / , * , axis : Literal [0 , 1 ], correction : int = 0 ) -> tuple [types .CupyArray , types .CupyArray ]: ...
146150@overload
147151def mean_var (x : types .DaskArray , / , * , axis : Literal [0 , 1 ] | None = None , correction : int = 0 ) -> tuple [types .DaskArray , types .DaskArray ]: ...
148-
149-
152+ @ overload
153+ def mean_var [ A : types . HasArrayNamespace ]( x : A , / , * , axis : Literal [ 0 , 1 ] | None = None , correction : int = 0 ) -> tuple [ A , A ]: ...
150154def mean_var (
151- x : CpuArray | GpuArray | types .DaskArray ,
155+ x : CpuArray | GpuArray | types .DaskArray | types . HasArrayNamespace ,
152156 / ,
153157 * ,
154158 axis : Literal [0 , 1 ] | None = None ,
@@ -158,6 +162,7 @@ def mean_var(
158162 | tuple [NDArray [np .float64 ], NDArray [np .float64 ]]
159163 | tuple [types .CupyArray , types .CupyArray ]
160164 | tuple [types .DaskArray , types .DaskArray ]
165+ | tuple [types .HasArrayNamespace , types .HasArrayNamespace ]
161166):
162167 """Mean and variance over both or one axis.
163168
@@ -249,8 +254,10 @@ def min(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ
249254def min (x : GpuArray , / , * , axis : Literal [0 , 1 ], keep_cupy_as_array : bool = False ) -> types .CupyArray : ...
250255@overload
251256def min (x : types .DaskArray , / , * , axis : Literal [0 , 1 ] | None = None , keep_cupy_as_array : bool = False ) -> types .DaskArray : ...
257+ @overload
258+ def min [A : types .HasArrayNamespace ](x : A , / , * , axis : Literal [0 , 1 ] | None = None , keep_cupy_as_array : bool = False ) -> A : ...
252259def min (
253- x : CpuArray | GpuArray | DiskArray | types .DaskArray ,
260+ x : CpuArray | GpuArray | DiskArray | types .DaskArray | types . HasArrayNamespace ,
254261 / ,
255262 * ,
256263 axis : Literal [0 , 1 ] | None = None ,
@@ -304,8 +311,10 @@ def max(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ
304311def max (x : GpuArray , / , * , axis : Literal [0 , 1 ], keep_cupy_as_array : bool = False ) -> types .CupyArray : ...
305312@overload
306313def max (x : types .DaskArray , / , * , axis : Literal [0 , 1 ] | None = None , keep_cupy_as_array : bool = False ) -> types .DaskArray : ...
314+ @overload
315+ def max [A : types .HasArrayNamespace ](x : A , / , * , axis : Literal [0 , 1 ] | None = None , keep_cupy_as_array : bool = False ) -> A : ...
307316def max (
308- x : CpuArray | GpuArray | DiskArray | types .DaskArray ,
317+ x : CpuArray | GpuArray | DiskArray | types .DaskArray | types . HasArrayNamespace ,
309318 / ,
310319 * ,
311320 axis : Literal [0 , 1 ] | None = None ,
0 commit comments