Skip to content

Commit 3d4ee3a

Browse files
committed
types
1 parent d22f74b commit 3d4ee3a

5 files changed

Lines changed: 22 additions & 19 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ repos:
3232
- array-api-compat>=1.13
3333
- dask>=2026.1
3434
- h5py>=3.15
35+
- jax>=0.10
3536
- numba>=0.63
3637
- packaging>=26
3738
- pytest>=9

src/fast_array_utils/conv/_to_dense.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _to_dense_numpy(x: np.ndarray, /, *, order: Literal["K", "A", "C", "F"] = "K
4747

4848

4949
@to_dense_.register(types.HasArrayNamespace)
50-
def _to_dense_array_api(x: types.HasArrayNamespace, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> Any: # noqa: ANN401
50+
def _to_dense_array_api[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> A | np.ndarray:
5151
if to_cpu_memory:
5252
return np.asarray(x, order=order)
5353
return x

src/fast_array_utils/stats/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,13 @@ def _mk_generic_op(op: DtypeOps) -> StatFunDtype: ...
214214
# https://github.com/scverse/fast-array-utils/issues/52
215215
def _mk_generic_op(op: Ops) -> StatFunNoDtype | StatFunDtype:
216216
def _generic_op(
217-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
217+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
218218
/,
219219
*,
220220
axis: Literal[0, 1] | None = None,
221221
dtype: DTypeLike | None = None,
222222
keep_cupy_as_array: bool = False,
223-
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
223+
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray | types.HasArrayNamespace:
224224
from ._generic_ops import generic_op
225225

226226
assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation {op!r}"
@@ -359,14 +359,16 @@ def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, ke
359359
def sum(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ...
360360
@overload
361361
def sum(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ...
362+
@overload
363+
def sum[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> A: ...
362364
def sum(
363-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
365+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
364366
/,
365367
*,
366368
axis: Literal[0, 1] | None = None,
367369
dtype: DTypeLike | None = None,
368370
keep_cupy_as_array: bool = False,
369-
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray:
371+
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace:
370372
"""Sum over both or one axis.
371373
372374
Parameters

src/fast_array_utils/stats/_generic_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,22 +59,22 @@ def _generic_op_numpy(
5959

6060

6161
@generic_op.register(types.HasArrayNamespace)
62-
def _generic_op_array_api(
63-
x: types.HasArrayNamespace,
62+
def _generic_op_array_api[A: types.HasArrayNamespace](
63+
x: A,
6464
/,
6565
op: Ops,
6666
*,
6767
axis: Literal[0, 1] | None = None,
6868
dtype: DTypeLike | None = None,
6969
keep_cupy_as_array: bool = False,
70-
) -> Any: # noqa: ANN401
70+
) -> A:
7171
"""Handle arrays with native array API support."""
7272
del keep_cupy_as_array
7373

7474
import array_api_compat
7575

7676
xp = array_api_compat.array_namespace(x)
77-
return getattr(xp, op)(x, axis=axis, **_dtype_kw(dtype, op))
77+
return getattr(xp, op)(x, axis=axis, **_dtype_kw(dtype, op)) # type: ignore[no-any-return]
7878

7979

8080
@generic_op.register(types.CupyArray | types.CupyCSMatrix)

tests/test_jax.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
if TYPE_CHECKING:
15-
from typing import Any, Literal
15+
from typing import Literal
1616

1717
pytestmark = pytest.mark.skipif(not find_spec("jax"), reason="jax not installed")
1818

@@ -21,18 +21,18 @@
2121
# problem as mean_var passes dtype= np.float64 internally, which crashes without this fix
2222
import jax
2323

24-
jax.config.update("jax_enable_x64", True) # noqa: FBT003
24+
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call] # noqa: FBT003
2525

2626

2727
@pytest.fixture
28-
def jax_arr() -> Any: # noqa: ANN401
28+
def jax_arr() -> jax.Array:
2929
import jax.numpy as jnp
3030

3131
return jnp.array([[1, 0], [2, 0], [3, 0]], dtype=jnp.float32)
3232

3333

3434
@pytest.mark.parametrize("axis", [None, 0, 1])
35-
def test_sum(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
35+
def test_sum(jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
3636
import jax.numpy as jnp
3737

3838
result = stats.sum(jax_arr, axis=axis)
@@ -41,7 +41,7 @@ def test_sum(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
4141

4242

4343
@pytest.mark.parametrize("axis", [None, 0, 1])
44-
def test_min(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
44+
def test_min(jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
4545
import jax.numpy as jnp
4646

4747
result = stats.min(jax_arr, axis=axis)
@@ -50,7 +50,7 @@ def test_min(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
5050

5151

5252
@pytest.mark.parametrize("axis", [None, 0, 1])
53-
def test_max(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
53+
def test_max(jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
5454
import jax.numpy as jnp
5555

5656
result = stats.max(jax_arr, axis=axis)
@@ -59,7 +59,7 @@ def test_max(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
5959

6060

6161
@pytest.mark.parametrize("axis", [None, 0, 1])
62-
def test_mean(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
62+
def test_mean(jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
6363
import jax.numpy as jnp
6464

6565
result = stats.mean(jax_arr, axis=axis)
@@ -95,7 +95,7 @@ def test_is_constant(axis: Literal[0, 1] | None) -> None:
9595

9696

9797
@pytest.mark.parametrize("axis", [None, 0, 1])
98-
def test_mean_var(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
98+
def test_mean_var(jax_arr: jax.Array, axis: Literal[0, 1] | None) -> None:
9999
import jax.numpy as jnp
100100

101101
mean, var = stats.mean_var(jax_arr, axis=axis, correction=1)
@@ -108,14 +108,14 @@ def test_mean_var(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: AN
108108
assert jnp.allclose(var, var_expected)
109109

110110

111-
def test_to_dense(jax_arr: Any) -> None: # noqa: ANN401
111+
def test_to_dense(jax_arr: jax.Array) -> None:
112112
import jax.numpy as jnp
113113

114114
result = to_dense(jax_arr)
115115
assert jnp.array_equal(result, jax_arr)
116116

117117

118-
def test_to_dense_to_cpu(jax_arr: Any) -> None: # noqa: ANN401
118+
def test_to_dense_to_cpu(jax_arr: jax.Array) -> None:
119119
result = to_dense(jax_arr, to_cpu_memory=True)
120120
assert isinstance(result, np.ndarray)
121121
np.testing.assert_array_equal(result, np.asarray(jax_arr))

0 commit comments

Comments
 (0)