Skip to content

Commit febaf24

Browse files
amalia-k510pre-commit-ci[bot]flying-sheep
authored
Adding array-api-compat fallback (#159)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Philipp A. <flying-sheep@web.de>
1 parent a54356d commit febaf24

17 files changed

Lines changed: 280 additions & 72 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ jobs:
5656
with:
5757
python-version: ${{ matrix.env.python }}
5858
- name: create environment
59-
run: uvx hatch env create ${{ matrix.env.name }}
59+
run: uvx hatch -v env create ${{ matrix.env.name }}
6060
- name: run tests with coverage
6161
run: |
6262
uvx hatch run ${{ matrix.env.name }}:run-cov

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ repos:
3232
- array-api-compat>=1.13
3333
- dask>=2026.1
3434
- h5py>=3.15
35+
- jax>=0.10
3536
- numba>=0.63
3637
- packaging>=26
3738
- pytest>=9

pyproject.toml

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ classifiers = [
2222
"Programming Language :: Python :: 3.13",
2323
"Programming Language :: Python :: 3.14",
2424
]
25-
dynamic = [ "description", "readme", "version" ]
26-
dependencies = [ "numpy>=2" ]
25+
dynamic = [ "version" ]
26+
dependencies = [ "array-api-compat", "numpy>=2" ]
2727
optional-dependencies.accel = [ "numba>=0.57" ]
2828
optional-dependencies.dask = [ "dask>=2023.6.1" ]
2929
optional-dependencies.full = [ "fast-array-utils[accel,dask,sparse]", "h5py", "zarr" ]
@@ -37,9 +37,10 @@ entry-points.pytest11.fast_array_utils = "testing.fast_array_utils.pytest"
3737
[dependency-groups]
3838
test = [
3939
"anndata",
40-
"fast-array-utils[accel]",
40+
"fast-array-utils[full]",
41+
"jax",
42+
"jaxlib",
4143
"scikit-learn",
42-
"zarr",
4344
{ include-group = "test-min" },
4445
]
4546
doc = [
@@ -66,12 +67,8 @@ envs.docs.scripts.clean = "git clean -fdX docs"
6667
envs.docs.scripts.open = "python -m webbrowser -t docs/_build/html/index.html"
6768
envs.hatch-test.default-args = []
6869
envs.hatch-test.dependency-groups = [ "test-min" ]
69-
# TODO: remove scipy once https://github.com/pypa/hatch/pull/2127 is released
70-
envs.hatch-test.extra-dependencies = [ "ipykernel", "ipycytoscape", "scipy" ]
70+
envs.hatch-test.extra-dependencies = [ "ipykernel", "ipycytoscape" ]
7171
envs.hatch-test.env-vars.CODSPEED_PROFILE_FOLDER = "test-data/codspeed"
72-
envs.hatch-test.overrides.matrix.extras.features = [
73-
{ if = [ "full" ], value = "full" },
74-
]
7572
envs.hatch-test.overrides.matrix.extras.dependency-groups = [
7673
{ if = [ "full" ], value = "test" },
7774
]
@@ -85,9 +82,10 @@ envs.hatch-test.matrix = [
8582
{ python = [ "3.14", "3.12" ], extras = [ "full", "min" ] },
8683
{ python = [ "3.12" ], extras = [ "full" ], resolution = [ "lowest" ] },
8784
]
88-
metadata.hooks.docstring-description = {}
89-
metadata.hooks.fancy-pypi-readme.content-type = "text/x-rst"
90-
metadata.hooks.fancy-pypi-readme.fragments = [ { path = "README.rst", start-after = ".. begin" } ]
85+
# TODO: re-activate incl. `dynamic = [ "description", "readme", ... ]` after https://github.com/pypa/hatch/issues/2252
86+
# metadata.hooks.docstring-description = {}
87+
# metadata.hooks.fancy-pypi-readme.content-type = "text/x-rst"
88+
# metadata.hooks.fancy-pypi-readme.fragments = [ { path = "README.rst", start-after = ".. begin" } ]
9189
version.source = "vcs"
9290
version.raw-options = { local_scheme = "no-local-version" } # be able to publish dev version
9391

src/fast_array_utils/conv/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,19 @@ def to_dense(x: GpuArray | types.CupySpMatrix, /, *, order: Literal["K", "A", "C
3838
def to_dense(x: GpuArray | types.CupySpMatrix, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[True]) -> NDArray[Any]: ...
3939

4040

41+
@overload
42+
def to_dense[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[False] = False) -> A: ...
43+
@overload
44+
def to_dense[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: Literal[True]) -> NDArray[Any]: ...
45+
46+
4147
def to_dense(
42-
x: CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix,
48+
x: CpuArray | GpuArray | DiskArray | types.CSDataset | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix | types.HasArrayNamespace,
4349
/,
4450
*,
4551
order: Literal["K", "A", "C", "F"] = "K",
4652
to_cpu_memory: bool = False,
47-
) -> NDArray[Any] | types.DaskArray | types.CupyArray:
53+
) -> NDArray[Any] | types.DaskArray | types.CupyArray | types.HasArrayNamespace:
4854
r"""Convert x to a dense array.
4955
5056
If ``to_cpu_memory`` is :data:`False`, :class:`dask.array.Array`\ s and

src/fast_array_utils/conv/_to_dense.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# fallback’s arg0 type has to include types of registered functions
2222
@singledispatch
2323
def to_dense_(
24-
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix,
24+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix | types.HasArrayNamespace,
2525
/,
2626
*,
2727
order: Literal["K", "A", "C", "F"] = "K",
@@ -39,6 +39,13 @@ def _to_dense_cs(x: types.spmatrix | types.sparray, /, *, order: Literal["K", "A
3939
return scipy.to_dense(x, order=sparse_order(x, order=order))
4040

4141

42+
@to_dense_.register(np.ndarray)
43+
def _to_dense_numpy(x: np.ndarray, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> np.ndarray:
44+
# to bypass the _to_dense_array_api path
45+
del to_cpu_memory
46+
return np.asarray(x, order=order)
47+
48+
4249
@to_dense_.register(types.DaskArray)
4350
def _to_dense_dask(x: types.DaskArray, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> NDArray[Any] | types.DaskArray:
4451
from . import to_dense
@@ -69,6 +76,13 @@ def _to_dense_cupy(x: GpuArray, /, *, order: Literal["K", "A", "C", "F"] = "K",
6976
return x.get(order="A") if to_cpu_memory else x
7077

7178

79+
@to_dense_.register(types.HasArrayNamespace)
80+
def _to_dense_array_api[A: types.HasArrayNamespace](x: A, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> A | np.ndarray:
81+
if to_cpu_memory:
82+
return np.asarray(x, order=order)
83+
return x
84+
85+
7286
def sparse_order(x: types.spmatrix | types.sparray | types.CupySpMatrix | types.CSDataset, /, *, order: Literal["K", "A", "C", "F"]) -> Literal["C", "F"]:
7387
if TYPE_CHECKING:
7488
from scipy.sparse._base import _spbase

src/fast_array_utils/stats/__init__.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,16 @@ def is_constant(x: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]) -> ND
3737
def is_constant(x: types.CupyArray, /, *, axis: Literal[0, 1]) -> types.CupyArray: ...
3838
@overload
3939
def is_constant(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None) -> types.DaskArray: ...
40+
@overload
41+
def is_constant[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None) -> bool | A: ...
4042

4143

4244
def is_constant(
43-
x: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
45+
x: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray | types.HasArrayNamespace,
4446
/,
4547
*,
4648
axis: Literal[0, 1] | None = None,
47-
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray:
49+
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray | types.HasArrayNamespace:
4850
"""Check whether values in array are constant.
4951
5052
Parameters
@@ -90,15 +92,17 @@ def mean(x: CpuArray | DiskArray, /, *, axis: Literal[0, 1], dtype: DTypeLike |
9092
def mean(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None) -> types.CupyArray: ...
9193
@overload
9294
def mean(x: types.DaskArray, /, *, axis: Literal[0, 1], dtype: ToDType[Any] | None = None) -> types.DaskArray: ...
95+
@overload
96+
def mean[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None) -> A: ...
9397

9498

9599
def mean(
96-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
100+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
97101
/,
98102
*,
99103
axis: Literal[0, 1] | None = None,
100104
dtype: DTypeLike | None = None,
101-
) -> NDArray[np.number[Any]] | types.CupyArray | np.number[Any] | types.DaskArray:
105+
) -> NDArray[np.number[Any]] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace:
102106
"""Mean over both or one axis.
103107
104108
Parameters
@@ -145,10 +149,10 @@ def mean_var(x: CpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tup
145149
def mean_var(x: GpuArray, /, *, axis: Literal[0, 1], correction: int = 0) -> tuple[types.CupyArray, types.CupyArray]: ...
146150
@overload
147151
def mean_var(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[types.DaskArray, types.DaskArray]: ...
148-
149-
152+
@overload
153+
def mean_var[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, correction: int = 0) -> tuple[A, A]: ...
150154
def mean_var(
151-
x: CpuArray | GpuArray | types.DaskArray,
155+
x: CpuArray | GpuArray | types.DaskArray | types.HasArrayNamespace,
152156
/,
153157
*,
154158
axis: Literal[0, 1] | None = None,
@@ -158,6 +162,7 @@ def mean_var(
158162
| tuple[NDArray[np.float64], NDArray[np.float64]]
159163
| tuple[types.CupyArray, types.CupyArray]
160164
| tuple[types.DaskArray, types.DaskArray]
165+
| tuple[types.HasArrayNamespace, types.HasArrayNamespace]
161166
):
162167
"""Mean and variance over both or one axis.
163168
@@ -214,13 +219,13 @@ def _mk_generic_op(op: DtypeOps) -> StatFunDtype: ...
214219
# https://github.com/scverse/fast-array-utils/issues/52
215220
def _mk_generic_op(op: Ops) -> StatFunNoDtype | StatFunDtype:
216221
def _generic_op(
217-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
222+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
218223
/,
219224
*,
220225
axis: Literal[0, 1] | None = None,
221226
dtype: DTypeLike | None = None,
222227
keep_cupy_as_array: bool = False,
223-
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
228+
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray | types.HasArrayNamespace:
224229
from ._generic_ops import generic_op
225230

226231
assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation {op!r}"
@@ -249,8 +254,10 @@ def min(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ
249254
def min(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ...
250255
@overload
251256
def min(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ...
257+
@overload
258+
def min[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> A: ...
252259
def min(
253-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
260+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
254261
/,
255262
*,
256263
axis: Literal[0, 1] | None = None,
@@ -304,8 +311,10 @@ def max(x: GpuArray, /, *, axis: None, keep_cupy_as_array: Literal[True]) -> typ
304311
def max(x: GpuArray, /, *, axis: Literal[0, 1], keep_cupy_as_array: bool = False) -> types.CupyArray: ...
305312
@overload
306313
def max(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ...
314+
@overload
315+
def max[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, keep_cupy_as_array: bool = False) -> A: ...
307316
def max(
308-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
317+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
309318
/,
310319
*,
311320
axis: Literal[0, 1] | None = None,
@@ -359,14 +368,16 @@ def sum(x: GpuArray, /, *, axis: None = None, dtype: DTypeLike | None = None, ke
359368
def sum(x: GpuArray, /, *, axis: Literal[0, 1], dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.CupyArray: ...
360369
@overload
361370
def sum(x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> types.DaskArray: ...
371+
@overload
372+
def sum[A: types.HasArrayNamespace](x: A, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, keep_cupy_as_array: bool = False) -> A: ...
362373
def sum(
363-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
374+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
364375
/,
365376
*,
366377
axis: Literal[0, 1] | None = None,
367378
dtype: DTypeLike | None = None,
368379
keep_cupy_as_array: bool = False,
369-
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray:
380+
) -> NDArray[Any] | types.CupyArray | np.number[Any] | types.DaskArray | types.HasArrayNamespace:
370381
"""Sum over both or one axis.
371382
372383
Parameters

src/fast_array_utils/stats/_generic_ops.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,34 +22,32 @@
2222
type ComplexAxis = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None
2323

2424

25-
def _run_numpy_op(
26-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
25+
@singledispatch
26+
def generic_op(
27+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
28+
/,
2729
op: Ops,
2830
*,
2931
axis: Literal[0, 1] | None = None,
3032
dtype: DTypeLike | None = None,
33+
keep_cupy_as_array: bool = False,
3134
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
32-
arr = cast("NDArray[Any] | np.number[Any] | types.CupyArray | types.CupyCOOMatrix | types.DaskArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
33-
return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr
35+
raise NotImplementedError # pragma: no cover
3436

3537

36-
@singledispatch
37-
def generic_op(
38-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
38+
@generic_op.register(np.ndarray | types.H5Dataset | types.ZarrArray)
39+
# register explicitly to avoid the array API path and performance slow down
40+
def _generic_op_numpy_disk(
41+
x: np.ndarray | DiskArray,
3942
/,
4043
op: Ops,
4144
*,
4245
axis: Literal[0, 1] | None = None,
4346
dtype: DTypeLike | None = None,
4447
keep_cupy_as_array: bool = False,
45-
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
48+
) -> NDArray[Any] | np.number[Any]:
4649
del keep_cupy_as_array
47-
if TYPE_CHECKING:
48-
# these are never passed to this fallback function, but `singledispatch` wants them
49-
assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix)
50-
# np supports these, but doesn’t know it. (TODO: test cupy)
51-
assert not isinstance(x, types.ZarrArray | types.H5Dataset)
52-
return cast("NDArray[Any] | np.number[Any]", _run_numpy_op(x, op, axis=axis, dtype=dtype))
50+
return getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)) # type: ignore[no-any-return]
5351

5452

5553
@generic_op.register(types.CupyArray | types.CupyCSMatrix)
@@ -62,7 +60,8 @@ def _generic_op_cupy(
6260
dtype: DTypeLike | None = None,
6361
keep_cupy_as_array: bool = False,
6462
) -> types.CupyArray | np.number[Any]:
65-
arr = cast("types.CupyArray", _run_numpy_op(x, op, axis=axis, dtype=dtype))
63+
arr = cast("types.CupyArray | types.CupyCOOMatrix", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
64+
arr = arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr
6665
return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze()
6766

6867

@@ -109,3 +108,22 @@ def _generic_op_dask(
109108
dtype = getattr(np, op)(np.zeros(1, dtype=x.dtype)).dtype
110109

111110
return _dask_inner(x, op, axis=axis, dtype=dtype, keep_cupy_as_array=keep_cupy_as_array)
111+
112+
113+
@generic_op.register(types.HasArrayNamespace)
114+
def _generic_op_array_api[A: types.HasArrayNamespace](
115+
x: A,
116+
/,
117+
op: Ops,
118+
*,
119+
axis: Literal[0, 1] | None = None,
120+
dtype: DTypeLike | None = None,
121+
keep_cupy_as_array: bool = False,
122+
) -> A:
123+
"""Handle arrays with native array API support."""
124+
del keep_cupy_as_array
125+
126+
import array_api_compat
127+
128+
xp = array_api_compat.array_namespace(x)
129+
return getattr(xp, op)(x, axis=axis, **_dtype_kw(dtype, op)) # type: ignore[no-any-return]

src/fast_array_utils/stats/_is_constant.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
from functools import partial, singledispatch
5-
from typing import TYPE_CHECKING, cast
5+
from typing import TYPE_CHECKING
66

77
import numba
88
import numpy as np
@@ -19,29 +19,29 @@
1919

2020
@singledispatch
2121
def is_constant_(
22-
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
22+
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray | types.HasArrayNamespace,
2323
/,
2424
*,
2525
axis: Literal[0, 1] | None = None,
2626
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray: # pragma: no cover
27-
raise NotImplementedError
27+
raise NotImplementedError # pragma: no cover
2828

2929

30-
@is_constant_.register(np.ndarray | types.CupyArray)
30+
@is_constant_.register(np.ndarray | types.CupyArray | types.HasArrayNamespace)
3131
def _is_constant_ndarray(a: NDArray[Any] | types.CupyArray, /, *, axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool] | types.CupyArray:
3232
# Should eventually support nd, not now.
3333
match axis:
3434
case None:
35-
return bool((a == a.flat[0]).all())
35+
return bool((a == a.reshape(-1)[0]).all())
3636
case 0:
3737
return _is_constant_rows(a.T)
3838
case 1:
3939
return _is_constant_rows(a)
4040

4141

4242
def _is_constant_rows(a: NDArray[Any] | types.CupyArray) -> NDArray[np.bool] | types.CupyArray:
43-
b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape)
44-
return cast("NDArray[np.bool]", (a == b).all(axis=1))
43+
# broadcasts without needing np.broadcast_to
44+
return (a == a[:, 0:1]).all(axis=1)
4545

4646

4747
@is_constant_.register(types.CSBase)

src/fast_array_utils/stats/_mean.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818

1919

2020
def mean_(
21-
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
21+
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
2222
/,
2323
*,
2424
axis: Literal[0, 1] | None = None,
2525
dtype: DTypeLike | None = None,
26-
) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray:
26+
) -> NDArray[np.number[Any]] | np.number[Any] | types.DaskArray | types.HasArrayNamespace:
2727
total = sum(x, axis=axis, dtype=dtype) # type: ignore[misc,arg-type]
2828
n = np.prod(x.shape) if axis is None else x.shape[axis]
2929
return total / n # type: ignore[no-any-return]

0 commit comments

Comments
 (0)