@@ -103,9 +103,11 @@ def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
103103 return np_arr
104104
105105
106- def to_np_dense_checked (
107- stat : NDArray [DTypeOut ] | np .number [Any ] | types .DaskArray , axis : Literal [0 , 1 ] | None , arr : CpuArray | GpuArray | DiskArray | types .DaskArray
108- ) -> NDArray [DTypeOut ] | np .number [Any ]:
106+ def to_np_dense_checked [DT : DTypeOut ](
107+ stat : NDArray [DT ] | np .number [Any ] | types .DaskArray | types .HasArrayNamespace ,
108+ axis : Literal [0 , 1 ] | None ,
109+ arr : CpuArray | GpuArray | DiskArray | types .COOBase | types .DaskArray | types .HasArrayNamespace ,
110+ ) -> NDArray [DT ] | np .number [Any ]:
109111 match axis , arr :
110112 case _, types .DaskArray ():
111113 assert isinstance (stat , types .DaskArray ), type (stat )
@@ -208,7 +210,7 @@ def test_min_max(array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.D
208210 np_arr = rng .random ((100 , 100 ))
209211 arr = array_type (np_arr )
210212
211- result = to_np_dense_checked (func (arr , axis = axis ), axis , arr )
213+ result = to_np_dense_checked (func (arr , axis = axis ), axis , arr ) # type: ignore[arg-type]
212214
213215 expected = (np .min if func is stats .min else np .max )(np_arr , axis = axis )
214216 np .testing .assert_array_equal (result , expected )
@@ -229,7 +231,7 @@ def test_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1]
229231 np_arr = np .array (data , dtype = np .float32 )
230232 arr = array_type (np_arr )
231233 assert 1 in arr .chunksize , "This test is supposed to test 1×n and n×1 chunk sizes"
232- stat = cast ("NDArray[Any] | types.CupyArray" , func (arr , axis = axis ).compute ())
234+ stat = cast ("NDArray[Any] | types.CupyArray" , func (arr , axis = axis ).compute ()) # type: ignore[union-attr]
233235 if isinstance (stat , types .CupyArray ):
234236 stat = stat .get ()
235237 np_func = getattr (np , func .__name__ )
@@ -321,6 +323,8 @@ def test_mean_var_pbmc_dask(array_type: ArrayType[types.DaskArray], pbmc64k_redu
321323 arr = array_type (mat )
322324
323325 mean_mat , var_mat = stats .mean_var (mat , axis = 0 , correction = 1 )
326+ mean_arr : NDArray [Any ] | np .number # actually just NDArray, and mypy should be able to infer.
327+ var_arr : NDArray [Any ] | np .number
324328 mean_arr , var_arr = (to_np_dense_checked (a , 0 , arr ) for a in stats .mean_var (arr , axis = 0 , correction = 1 ))
325329
326330 rtol = 1.0e-5 if array_type .flags & Flags .Gpu else 1.0e-7
0 commit comments