Skip to content

Commit 6d3891e

Browse files
committed
types, missing parameters
1 parent c77a1bc commit 6d3891e

5 files changed

Lines changed: 15 additions & 7 deletions

File tree

src/fast_array_utils/stats/_is_constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
@singledispatch
2121
def is_constant_(
22-
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
22+
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray | types.HasArrayNamespace,
2323
/,
2424
*,
2525
axis: Literal[0, 1] | None = None,

src/fast_array_utils/stats/_mean.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
def mean_(
21-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
21+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
2222
/,
2323
*,
2424
axis: Literal[0, 1] | None = None,

src/fast_array_utils/stats/_mean_var.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
@no_type_check # mypy is extremely confused
2323
def mean_var_(
24-
x: CpuArray | GpuArray | types.DaskArray,
24+
x: CpuArray | GpuArray | types.DaskArray | types.HasArrayNamespace,
2525
/,
2626
*,
2727
axis: Literal[0, 1] | None = None,

src/fast_array_utils/stats/_typing.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,27 @@ class StatFunNoDtype(Protocol):
2828
__name__: str
2929

3030
def __call__(
31-
self, x: CpuArray | GpuArray | DiskArray | types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False
32-
) -> types.DaskArray: ...
31+
self,
32+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
33+
/,
34+
*,
35+
axis: Literal[0, 1] | None = None,
36+
keep_cupy_as_array: bool = False,
37+
) -> types.DaskArray | types.HasArrayNamespace: ...
3338

3439

3540
class StatFunDtype(Protocol):
3641
__name__: str
3742

3843
def __call__(
3944
self,
40-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
45+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
4146
/,
4247
*,
4348
axis: Literal[0, 1] | None = None,
4449
dtype: DTypeLike | None = None,
4550
keep_cupy_as_array: bool = False,
46-
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray: ...
51+
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace: ...
4752

4853

4954
NoDtypeOps = Literal["max", "min"]

src/fast_array_utils/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,4 +124,7 @@
124124
class HasArrayNamespace(Protocol):
125125
"""An array object compatible with the Python array API standard."""
126126

127+
ndim: int
128+
shape: tuple[int, ...]
129+
127130
def __array_namespace__(self, /, *, api_version: str | None = None) -> ModuleType: ...

0 commit comments

Comments
 (0)