Merged
43 commits (changes shown from 25 commits)
64f5304
array-api initially implementation
amalia-k510 Mar 23, 2026
5373050
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2026
f016c39
updates in regards to the handler and some array_api handling fixes
amalia-k510 Apr 15, 2026
cca3ad6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2026
c1bb155
Issues with jax test are fixed, introduced similar tests with pytorch
amalia-k510 Apr 15, 2026
c9a8f85
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 15, 2026
bd08c2e
pre-commit fixes
amalia-k510 Apr 15, 2026
b2e3f9b
mipy issues fix
amalia-k510 Apr 15, 2026
65e83a6
speed up fix
amalia-k510 Apr 15, 2026
2a1924d
Update src/fast_array_utils/stats/_mean_var.py
amalia-k510 Apr 17, 2026
9684213
Addressed the comments
amalia-k510 Apr 17, 2026
1e3c296
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2026
63c5e16
chore: simplify
flying-sheep Apr 20, 2026
9c8466a
addressing comments about removing is_array_api_obj check
amalia-k510 Apr 27, 2026
86a7503
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
e28f176
import fix
amalia-k510 Apr 27, 2026
c704259
ignore comments update and mypy test
amalia-k510 Apr 27, 2026
ab6e200
Merge branch 'main' into array-api-implementation
amalia-k510 Apr 27, 2026
aaccbda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
7e5102a
main version
amalia-k510 Apr 27, 2026
eed16a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
567107a
residues ignore comments removed
amalia-k510 Apr 27, 2026
1e41d24
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
aab50f7
pyproject, jax optional dependencies
amalia-k510 Apr 27, 2026
91ef896
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
0f62abc
commented addressed, mypy try again
amalia-k510 Apr 27, 2026
d642668
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
37a6634
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
bab4392
mypy comment add
amalia-k510 Apr 27, 2026
d22f74b
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
3d4ee3a
types
flying-sheep Apr 27, 2026
c77a1bc
types for others
amalia-k510 Apr 29, 2026
6d3891e
types, missing parameters
amalia-k510 Apr 29, 2026
9e2a9fe
revert pyproject.toml
amalia-k510 Apr 29, 2026
57117d2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
0415fd5
rework deps
flying-sheep Apr 30, 2026
3857917
Merge branch 'main' into array-api-implementation
flying-sheep Apr 30, 2026
592b7b7
fix deps
flying-sheep Apr 30, 2026
c890bc3
fix types
flying-sheep Apr 30, 2026
49283b0
fix cupy tests
flying-sheep Apr 30, 2026
f6463cc
fix disk array
flying-sheep Apr 30, 2026
474c969
fmt
flying-sheep Apr 30, 2026
5de4e5b
coverage
flying-sheep Apr 30, 2026
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [ "numpy>=2" ]
optional-dependencies.accel = [ "numba>=0.57" ]
optional-dependencies.dask = [ "dask>=2023.6.1" ]
optional-dependencies.full = [ "fast-array-utils[accel,dask,sparse]", "h5py", "zarr" ]
optional-dependencies.jax = [ "jax", "jaxlib" ]
optional-dependencies.sparse = [ "scipy>=1.13" ]
optional-dependencies.testing = [ "packaging" ]
urls."Issue Tracker" = "https://github.com/scverse/fast-array-utils/issues"
@@ -71,6 +72,7 @@ envs.hatch-test.extra-dependencies = [ "ipykernel", "ipycytoscape", "scipy" ]
envs.hatch-test.env-vars.CODSPEED_PROFILE_FOLDER = "test-data/codspeed"
envs.hatch-test.overrides.matrix.extras.features = [
{ if = [ "full" ], value = "full" },
{ if = [ "full" ], value = "jax" },
]
envs.hatch-test.overrides.matrix.extras.dependency-groups = [
{ if = [ "full" ], value = "test" },
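The new extra keeps jax strictly optional: nothing in the package imports it eagerly, and the test matrix only pulls it in alongside the "full" feature set. A minimal sketch of the availability guard the test suite uses (an assumption that downstream code copies the same pattern rather than importing jax unconditionally):

# optional-dependency guard, mirroring tests/test_jax.py further down
from importlib.util import find_spec

if find_spec("jax"):
    import jax

    jax.config.update("jax_enable_x64", True)  # JAX defaults to 32-bit precision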
16 changes: 15 additions & 1 deletion src/fast_array_utils/conv/_to_dense.py
@@ -21,7 +21,7 @@
# fallback’s arg0 type has to include types of registered functions
@singledispatch
def to_dense_(
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix,
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.sparray | types.spmatrix | types.CupySpMatrix | types.HasArrayNamespace,
/,
*,
order: Literal["K", "A", "C", "F"] = "K",
@@ -39,6 +39,20 @@ def _to_dense_cs(x: types.spmatrix | types.sparray, /, *, order: Literal["K", "A
return scipy.to_dense(x, order=sparse_order(x, order=order))


@to_dense_.register(np.ndarray)
def _to_dense_numpy(x: np.ndarray, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> np.ndarray:
# to bypass the _to_dense_array_api path
del to_cpu_memory
return np.asarray(x, order=order)


@to_dense_.register(types.HasArrayNamespace)
def _to_dense_array_api(x: types.HasArrayNamespace, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> Any: # noqa: ANN401
if to_cpu_memory:
return np.asarray(x, order=order)
return x


@to_dense_.register(types.DaskArray)
def _to_dense_dask(x: types.DaskArray, /, *, order: Literal["K", "A", "C", "F"] = "K", to_cpu_memory: bool = False) -> NDArray[Any] | types.DaskArray:
from . import to_dense
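In short, array-API objects now fall through to_dense unchanged unless to_cpu_memory=True, in which case they are materialized with np.asarray; plain ndarrays keep their own registration so they never hit that branch. A hedged usage sketch, assuming jax is installed:

import jax.numpy as jnp
import numpy as np

from fast_array_utils.conv import to_dense

x = jnp.arange(6.0).reshape(2, 3)

passthrough = to_dense(x)                 # array-API path: returned as-is
on_cpu = to_dense(x, to_cpu_memory=True)  # copied into CPU memory via np.asarray

assert isinstance(on_cpu, np.ndarray)
np.testing.assert_array_equal(on_cpu, np.asarray(x))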
54 changes: 40 additions & 14 deletions src/fast_array_utils/stats/_generic_ops.py
@@ -22,20 +22,9 @@
type ComplexAxis = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None


def _run_numpy_op(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
op: Ops,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
arr = cast("NDArray[Any] | np.number[Any] | types.CupyArray | types.CupyCOOMatrix | types.DaskArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr


@singledispatch
def generic_op(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
/,
op: Ops,
*,
@@ -49,7 +38,44 @@ def generic_op(
assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix)
# np supports these, but doesn’t know it. (TODO: test cupy)
assert not isinstance(x, types.ZarrArray | types.H5Dataset)
return cast("NDArray[Any] | np.number[Any]", _run_numpy_op(x, op, axis=axis, dtype=dtype))

arr = getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))
return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr


@generic_op.register(np.ndarray)
# register explicitly to avoid the array-API path and its performance slowdown
def _generic_op_numpy(
x: np.ndarray,
/,
op: Ops,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any]:
del keep_cupy_as_array
arr = getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))
return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr # type: ignore[return-value]


@generic_op.register(types.HasArrayNamespace)
def _generic_op_array_api(
x: types.HasArrayNamespace,
/,
op: Ops,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> Any: # noqa: ANN401
"""Handle arrays with native array API support."""
del keep_cupy_as_array

import array_api_compat

xp = array_api_compat.array_namespace(x)
return getattr(xp, op)(x, axis=axis, **_dtype_kw(dtype, op))


@generic_op.register(types.CupyArray | types.CupyCSMatrix)
@@ -62,7 +88,7 @@ def _generic_op_cupy(
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> types.CupyArray | np.number[Any]:
arr = cast("types.CupyArray", _run_numpy_op(x, op, axis=axis, dtype=dtype))
arr = cast("types.CupyArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze()


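The array-API branch boils down to resolving the namespace with array_api_compat.array_namespace and calling the reduction on it; the explicit np.ndarray registration keeps NumPy inputs on the old getattr(np, op) path. A standalone sketch of the namespace-based pattern, assuming jax and array-api-compat are installed:

import array_api_compat
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

op = "sum"  # one of the Ops handled by generic_op
xp = array_api_compat.array_namespace(x)  # the namespace jax exposes via __array_namespace__
col_sums = getattr(xp, op)(x, axis=0)     # what stats.sum(x, axis=0) dispatches to

print(col_sums)  # [4. 6.]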
10 changes: 5 additions & 5 deletions src/fast_array_utils/stats/_is_constant.py
@@ -2,7 +2,7 @@
from __future__ import annotations

from functools import partial, singledispatch
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

import numba
import numpy as np
@@ -27,21 +27,21 @@ def is_constant_(
raise NotImplementedError


@is_constant_.register(np.ndarray | types.CupyArray)
@is_constant_.register(np.ndarray | types.CupyArray | types.HasArrayNamespace)
def _is_constant_ndarray(a: NDArray[Any] | types.CupyArray, /, *, axis: Literal[0, 1] | None = None) -> bool | NDArray[np.bool] | types.CupyArray:
# Should eventually support nd, not now.
match axis:
case None:
return bool((a == a.flat[0]).all())
return bool((a == a.reshape(-1)[0]).all())
case 0:
return _is_constant_rows(a.T)
case 1:
return _is_constant_rows(a)


def _is_constant_rows(a: NDArray[Any] | types.CupyArray) -> NDArray[np.bool] | types.CupyArray:
b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape)
return cast("NDArray[np.bool]", (a == b).all(axis=1))
# broadcasts without needing np.broadcast_to
return (a == a[:, 0:1]).all(axis=1)


@is_constant_.register(types.CSBase)
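The rewritten _is_constant_rows relies on plain broadcasting: a[:, 0:1] keeps shape (n, 1), so the comparison broadcasts against a without an explicit np.broadcast_to, and a.reshape(-1)[0] replaces a.flat[0], which array-API objects need not provide. A quick NumPy illustration of the row check:

import numpy as np

a = np.array([[1, 1, 1], [2, 3, 2]])

# a[:, 0:1] has shape (2, 1) and broadcasts against a's shape (2, 3)
row_constant = (a == a[:, 0:1]).all(axis=1)
print(row_constant)  # [ True False]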
13 changes: 11 additions & 2 deletions src/fast_array_utils/stats/_mean_var.py
@@ -32,13 +32,22 @@ def mean_var_(
| tuple[np.float64, np.float64]
| tuple[types.DaskArray, types.DaskArray]
):

from . import mean

if isinstance(x, np.ndarray | types.CSBase):
xp = np
elif isinstance(x, types.HasArrayNamespace):
import array_api_compat

xp = array_api_compat.array_namespace(x)
else:
xp = np
if axis is not None and isinstance(x, types.CSBase):
mean_, var = _sparse_mean_var(x, axis=axis)
else:
mean_ = mean(x, axis=axis, dtype=np.float64)
mean_sq = mean(power(x, 2, dtype=np.float64), axis=axis) if isinstance(x, types.DaskArray) else mean(power(x, 2), axis=axis, dtype=np.float64)
mean_ = mean(x, axis=axis, dtype=xp.float64)
mean_sq = mean(power(x, 2, dtype=xp.float64), axis=axis) if isinstance(x, types.DaskArray) else mean(power(x, 2), axis=axis, dtype=xp.float64)
var = mean_sq - mean_**2
if correction: # R convention == 1 (unbiased estimator)
n = np.prod(x.shape) if axis is None else x.shape[axis]
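mean_var_ still uses Var[X] = E[X²] − (E[X])²; the only change is that float64 is looked up on the resolved namespace xp so the dtype object matches the input's library. A small pure-NumPy check of the formula and the correction step, for illustration:

import numpy as np

x = np.array([1.0, 2.0, 4.0])

mean_ = x.mean(dtype=np.float64)
mean_sq = (x**2).mean(dtype=np.float64)
var = mean_sq - mean_**2          # population variance (correction == 0)

n = x.size                        # np.prod(x.shape) when axis is None
var_unbiased = var * n / (n - 1)  # correction == 1, the R convention

assert np.isclose(var_unbiased, x.var(ddof=1))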
16 changes: 15 additions & 1 deletion src/fast_array_utils/stats/_power.py
@@ -15,7 +15,7 @@
from fast_array_utils.typing import CpuArray, GpuArray

# All supported array types except for disk ones and CSDataset
type Array = CpuArray | GpuArray | types.DaskArray
type Array = CpuArray | GpuArray | types.DaskArray | types.HasArrayNamespace


def power[Arr: Array](x: Arr, n: int, /, dtype: DTypeLike | None = None) -> Arr:
@@ -31,6 +31,20 @@ def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
return x**n if dtype is None else np.power(x, n, dtype=dtype) # type: ignore[operator]


@_power.register(np.ndarray)
def _power_numpy(x: np.ndarray, n: int, /, dtype: DTypeLike | None = None) -> np.ndarray:
# avoids slower xp.pow(xp.astype(...)) path
return x**n if dtype is None else np.power(x, n, dtype=dtype)


@_power.register(types.HasArrayNamespace)
def _power_array_api(x: types.HasArrayNamespace, n: int, /, dtype: DTypeLike | None = None) -> types.HasArrayNamespace:
import array_api_compat

xp = array_api_compat.array_namespace(x)
return xp.pow(x, n) if dtype is None else xp.pow(xp.astype(x, dtype), n) # type: ignore[no-any-return]


@_power.register(types.CSBase | types.CupyCSMatrix)
def _power_cs[Mat: types.CSBase | types.CupyCSMatrix](x: Mat, n: int, /, dtype: DTypeLike | None = None) -> Mat:
new_data = power(x.data, n, dtype=dtype)
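For array-API inputs, power now routes through xp.pow (plus xp.astype when a dtype is requested), while the explicit np.ndarray registration keeps the ** / np.power fast path. A hedged sketch of both paths side by side, assuming jax is installed and 64-bit mode is enabled as in the tests:

import array_api_compat
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # as in tests/test_jax.py; needed for float64

x_np = np.array([1.0, 2.0, 3.0])
x_jax = jnp.array([1.0, 2.0, 3.0])

np_result = np.power(x_np, 2, dtype=np.float64)       # _power_numpy path

xp = array_api_compat.array_namespace(x_jax)
jax_result = xp.pow(xp.astype(x_jax, xp.float64), 2)  # _power_array_api path

np.testing.assert_allclose(np.asarray(jax_result), np_result)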
11 changes: 10 additions & 1 deletion src/fast_array_utils/types.py
@@ -4,7 +4,7 @@
from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol, runtime_checkable


__all__ = [
@@ -37,6 +37,8 @@

# scipy sparse
if TYPE_CHECKING:
from types import ModuleType

from scipy.sparse import coo_array, coo_matrix, csc_array, csc_matrix, csr_array, csr_matrix, sparray, spmatrix
else:
try: # cs?_array isn’t available in older scipy versions
@@ -116,3 +118,10 @@
CSRDataset.__module__ = CSCDataset.__module__ = "anndata.abc"
CSDataset = CSRDataset | CSCDataset
"""Anndata sparse out-of-core matrices."""


@runtime_checkable
class HasArrayNamespace(Protocol):
"""An array object compatible with the Python array API standard."""

def __array_namespace__(self, /, *, api_version: str | None = None) -> ModuleType: ...
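
This protocol is what the new singledispatch registrations key on: because it is runtime_checkable, isinstance works structurally against any object that exposes __array_namespace__, without importing the backing library. A minimal sketch (DummyArray is invented for illustration); note that NumPy >= 2 ndarrays also satisfy the protocol, which is why np.ndarray gets its own explicit registrations above:

from types import ModuleType

import numpy as np

from fast_array_utils.types import HasArrayNamespace


class DummyArray:
    """Stand-in for a jax.Array, torch.Tensor, etc."""

    def __array_namespace__(self, /, *, api_version: str | None = None) -> ModuleType:
        return np  # a real array would return its library's array-API namespace


assert isinstance(DummyArray(), HasArrayNamespace)  # structural match, no registration needed
assert isinstance(np.ones(3), HasArrayNamespace)    # NumPy >= 2 defines __array_namespace__ too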
121 changes: 121 additions & 0 deletions tests/test_jax.py
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING

import numpy as np
import pytest

from fast_array_utils import stats
from fast_array_utils.conv import to_dense


if TYPE_CHECKING:
from typing import Any, Literal

pytestmark = pytest.mark.skipif(not find_spec("jax"), reason="jax not installed")

if find_spec("jax"):
# enable 64-bit precision in JAX, which defaults to 32-bit
# without it, mean_var crashes because it passes dtype=np.float64 internally
import jax

jax.config.update("jax_enable_x64", True) # noqa: FBT003


@pytest.fixture
def jax_arr() -> Any: # noqa: ANN401
import jax.numpy as jnp

return jnp.array([[1, 0], [2, 0], [3, 0]], dtype=jnp.float32)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_sum(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
import jax.numpy as jnp

result = stats.sum(jax_arr, axis=axis)
expected = jnp.sum(jax_arr, axis=axis)
assert jnp.array_equal(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_min(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
import jax.numpy as jnp

result = stats.min(jax_arr, axis=axis)
expected = jnp.min(jax_arr, axis=axis)
assert jnp.array_equal(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_max(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
import jax.numpy as jnp

result = stats.max(jax_arr, axis=axis)
expected = jnp.max(jax_arr, axis=axis)
assert jnp.array_equal(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_mean(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
import jax.numpy as jnp

result = stats.mean(jax_arr, axis=axis)
expected = jnp.mean(jax_arr, axis=axis)
assert jnp.allclose(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_is_constant(axis: Literal[0, 1] | None) -> None:
import jax.numpy as jnp

x = jnp.array(
[
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 0],
],
dtype=jnp.float32,
)
result = stats.is_constant(x, axis=axis)

if axis is None:
assert bool(result) is False
elif axis == 0:
expected = jnp.array([True, True, False, False])
assert jnp.array_equal(result, expected)
else:
expected = jnp.array([False, False, True, True, False, True])
assert jnp.array_equal(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_mean_var(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
import jax.numpy as jnp

mean, var = stats.mean_var(jax_arr, axis=axis, correction=1)

mean_expected = jnp.mean(jax_arr, axis=axis)
n = jax_arr.size if axis is None else jax_arr.shape[axis]
var_expected = jnp.var(jax_arr, axis=axis) * n / (n - 1)

assert jnp.allclose(mean, mean_expected)
assert jnp.allclose(var, var_expected)


def test_to_dense(jax_arr: Any) -> None: # noqa: ANN401
import jax.numpy as jnp

result = to_dense(jax_arr)
assert jnp.array_equal(result, jax_arr)


def test_to_dense_to_cpu(jax_arr: Any) -> None: # noqa: ANN401
result = to_dense(jax_arr, to_cpu_memory=True)
assert isinstance(result, np.ndarray)
np.testing.assert_array_equal(result, np.asarray(jax_arr))