Skip to content

Commit c890bc3

Browse files
committed
fix types
1 parent 592b7b7 commit c890bc3

6 files changed

Lines changed: 31 additions & 13 deletions

File tree

src/fast_array_utils/stats/_typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __call__(
3434
*,
3535
axis: Literal[0, 1] | None = None,
3636
keep_cupy_as_array: bool = False,
37-
) -> types.DaskArray | types.HasArrayNamespace: ...
37+
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace: ...
3838

3939

4040
class StatFunDtype(Protocol):

src/fast_array_utils/types.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
from typing import TYPE_CHECKING, Protocol, runtime_checkable
88

99

10+
if TYPE_CHECKING:
11+
from types import ModuleType
12+
13+
1014
__all__ = [
1115
"COOBase",
1216
"CSArray",
@@ -22,6 +26,7 @@
2226
"DaskArray",
2327
"H5Dataset",
2428
"H5Group",
29+
"HasArrayNamespace",
2530
"ZarrArray",
2631
"ZarrGroup",
2732
"coo_array",
@@ -37,8 +42,6 @@
3742

3843
# scipy sparse
3944
if TYPE_CHECKING:
40-
from types import ModuleType
41-
4245
from scipy.sparse import coo_array, coo_matrix, csc_array, csc_matrix, csr_array, csr_matrix, sparray, spmatrix
4346
else:
4447
try: # cs?_array isn’t available in older scipy versions
@@ -124,7 +127,17 @@
124127
class HasArrayNamespace(Protocol):
125128
"""An array object compatible with the Python array API standard."""
126129

127-
ndim: int
128-
shape: tuple[int, ...]
130+
@property
131+
def ndim(self) -> int:
132+
"""The number of dimensions of the array."""
133+
134+
@property
135+
def shape(self) -> tuple[int, ...]:
136+
"""The shape of the array."""
137+
138+
@property
139+
def dtype(self) -> object:
140+
"""The data type of the array."""
129141

130-
def __array_namespace__(self, /, *, api_version: str | None = None) -> ModuleType: ...
142+
def __array_namespace__(self, /, *, api_version: str | None = None) -> ModuleType:
143+
"""Get Array API namespace."""

tests/test_jax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
if TYPE_CHECKING:
1515
from typing import Literal
1616

17+
1718
pytestmark = pytest.mark.skipif(not find_spec("jax"), reason="jax not installed")
1819

1920
if find_spec("jax"):

tests/test_stats.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,11 @@ def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
103103
return np_arr
104104

105105

106-
def to_np_dense_checked(
107-
stat: NDArray[DTypeOut] | np.number[Any] | types.DaskArray, axis: Literal[0, 1] | None, arr: CpuArray | GpuArray | DiskArray | types.DaskArray
108-
) -> NDArray[DTypeOut] | np.number[Any]:
106+
def to_np_dense_checked[DT: DTypeOut](
107+
stat: NDArray[DT] | np.number[Any] | types.DaskArray | types.HasArrayNamespace,
108+
axis: Literal[0, 1] | None,
109+
arr: CpuArray | GpuArray | DiskArray | types.COOBase | types.DaskArray | types.HasArrayNamespace,
110+
) -> NDArray[DT] | np.number[Any]:
109111
match axis, arr:
110112
case _, types.DaskArray():
111113
assert isinstance(stat, types.DaskArray), type(stat)
@@ -208,7 +210,7 @@ def test_min_max(array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.D
208210
np_arr = rng.random((100, 100))
209211
arr = array_type(np_arr)
210212

211-
result = to_np_dense_checked(func(arr, axis=axis), axis, arr)
213+
result = to_np_dense_checked(func(arr, axis=axis), axis, arr) # type: ignore[arg-type]
212214

213215
expected = (np.min if func is stats.min else np.max)(np_arr, axis=axis)
214216
np.testing.assert_array_equal(result, expected)
@@ -229,7 +231,7 @@ def test_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1]
229231
np_arr = np.array(data, dtype=np.float32)
230232
arr = array_type(np_arr)
231233
assert 1 in arr.chunksize, "This test is supposed to test 1×n and n×1 chunk sizes"
232-
stat = cast("NDArray[Any] | types.CupyArray", func(arr, axis=axis).compute())
234+
stat = cast("NDArray[Any] | types.CupyArray", func(arr, axis=axis).compute()) # type: ignore[union-attr]
233235
if isinstance(stat, types.CupyArray):
234236
stat = stat.get()
235237
np_func = getattr(np, func.__name__)
@@ -321,6 +323,8 @@ def test_mean_var_pbmc_dask(array_type: ArrayType[types.DaskArray], pbmc64k_redu
321323
arr = array_type(mat)
322324

323325
mean_mat, var_mat = stats.mean_var(mat, axis=0, correction=1)
326+
mean_arr: NDArray[Any] | np.number # actually just NDArray, and mypy should be able to infer.
327+
var_arr: NDArray[Any] | np.number
324328
mean_arr, var_arr = (to_np_dense_checked(a, 0, arr) for a in stats.mean_var(arr, axis=0, correction=1))
325329

326330
rtol = 1.0e-5 if array_type.flags & Flags.Gpu else 1.0e-7

typings/cupyx/scipy/sparse/_coo.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ from ._base import spmatrix
88

99
class coo_matrix(spmatrix):
1010
format: Literal["coo"] = "coo"
11-
def get(self, stream: cupy.cuda.Stream | None = None) -> sps.spmatrix: ...
11+
def get(self, stream: cupy.cuda.Stream | None = None) -> sps.coo_matrix: ...

typings/cupyx/scipy/sparse/_csr.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ from ._compressed import _compressed_sparse_matrix
88

99
class csr_matrix(_compressed_sparse_matrix):
1010
format: Literal["csr"] = "csr"
11-
def get(self, stream: cupy.cuda.Stream | None = None) -> sps.csc_matrix: ...
11+
def get(self, stream: cupy.cuda.Stream | None = None) -> sps.csr_matrix: ...

0 commit comments

Comments
 (0)