Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
64f5304
array-api initial implementation
amalia-k510 Mar 23, 2026
5373050
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2026
f016c39
updates in regards to the handler and some array_api handling fixes
amalia-k510 Apr 15, 2026
cca3ad6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2026
c1bb155
Issues with jax test are fixed, introduced similar tests with pytorch
amalia-k510 Apr 15, 2026
c9a8f85
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 15, 2026
bd08c2e
pre-commit fixes
amalia-k510 Apr 15, 2026
b2e3f9b
mypy issues fix
amalia-k510 Apr 15, 2026
65e83a6
speed up fix
amalia-k510 Apr 15, 2026
2a1924d
Update src/fast_array_utils/stats/_mean_var.py
amalia-k510 Apr 17, 2026
9684213
Addressed the comments
amalia-k510 Apr 17, 2026
1e3c296
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2026
63c5e16
chore: simplify
flying-sheep Apr 20, 2026
9c8466a
addressing comments about removing is_array_api_obj check
amalia-k510 Apr 27, 2026
86a7503
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
e28f176
import fix
amalia-k510 Apr 27, 2026
c704259
ignore comments update and mypy test
amalia-k510 Apr 27, 2026
ab6e200
Merge branch 'main' into array-api-implementation
amalia-k510 Apr 27, 2026
aaccbda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
7e5102a
main version
amalia-k510 Apr 27, 2026
eed16a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
567107a
residues ignore comments removed
amalia-k510 Apr 27, 2026
1e41d24
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
aab50f7
pyproject, jax optional dependencies
amalia-k510 Apr 27, 2026
91ef896
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
0f62abc
commented addressed, mypy try again
amalia-k510 Apr 27, 2026
d642668
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
37a6634
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
bab4392
mypy comment add
amalia-k510 Apr 27, 2026
d22f74b
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
3d4ee3a
types
flying-sheep Apr 27, 2026
c77a1bc
types for others
amalia-k510 Apr 29, 2026
6d3891e
types, missing parameters
amalia-k510 Apr 29, 2026
9e2a9fe
revert pyproject.toml
amalia-k510 Apr 29, 2026
57117d2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
0415fd5
rework deps
flying-sheep Apr 30, 2026
3857917
Merge branch 'main' into array-api-implementation
flying-sheep Apr 30, 2026
592b7b7
fix deps
flying-sheep Apr 30, 2026
c890bc3
fix types
flying-sheep Apr 30, 2026
49283b0
fix cupy tests
flying-sheep Apr 30, 2026
f6463cc
fix disk array
flying-sheep Apr 30, 2026
474c969
fmt
flying-sheep Apr 30, 2026
5de4e5b
coverage
flying-sheep Apr 30, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ classifiers = [
"Programming Language :: Python :: 3.14",
]
dynamic = [ "description", "readme", "version" ]
dependencies = [ "numpy>=2" ]
dependencies = [ "array-api-compat", "numpy>=2" ]
optional-dependencies.accel = [ "numba>=0.57" ]
optional-dependencies.dask = [ "dask>=2023.6.1" ]
optional-dependencies.full = [ "fast-array-utils[accel,dask,sparse]", "h5py", "zarr" ]
optional-dependencies.jax = [ "jax", "jaxlib" ]
optional-dependencies.sparse = [ "scipy>=1.13" ]
optional-dependencies.testing = [ "packaging" ]
optional-dependencies.torch = [ "torch" ]
urls."Issue Tracker" = "https://github.com/scverse/fast-array-utils/issues"
urls."Source Code" = "https://github.com/scverse/fast-array-utils"
urls.Documentation = "https://icb-fast-array-utils.readthedocs-hosted.com/"
Expand All @@ -48,6 +50,15 @@ doc = [
"sphinx>=9.0.1",
"sphinx-autofixture>=0.4.1",
]
# for update-mypy-hook
mypy = [
"fast-array-utils[full]",
"scipy-stubs",
# TODO: replace sphinx with this: { include-group = "doc" },
"sphinx",
"types-docutils",
{ include-group = "test" },
]
Comment thread
amalia-k510 marked this conversation as resolved.
Outdated
test-min = [
"coverage[toml]",
"fast-array-utils[sparse,testing]", # include sparse for testing numba-less to_dense
Expand Down
2 changes: 1 addition & 1 deletion src/fast_array_utils/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ def to_dense(
Dense form of ``x``

"""
return to_dense_(x, order=order, to_cpu_memory=to_cpu_memory)
return to_dense_(x, order=order, to_cpu_memory=to_cpu_memory) # type: ignore[no-any-return]
9 changes: 8 additions & 1 deletion src/fast_array_utils/conv/_to_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@ def to_dense_(
*,
order: Literal["K", "A", "C", "F"] = "K",
to_cpu_memory: bool = False,
) -> NDArray[Any] | types.CupyArray | types.DaskArray:
) -> Any: # noqa: ANN401
import array_api_compat

if not isinstance(x, np.ndarray) and array_api_compat.is_array_api_obj(x):
if to_cpu_memory:
return np.asarray(x, order=order)
return x # array API standard covers dense arrays; sparse types are handled by registered dispatches

del to_cpu_memory # it already is
return np.asarray(x, order=order)

Expand Down
4 changes: 2 additions & 2 deletions src/fast_array_utils/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def is_constant(
from ._is_constant import is_constant_

validate_axis(x.ndim, axis)
return is_constant_(x, axis=axis)
return is_constant_(x, axis=axis) # type: ignore[no-any-return]


# TODO(flying-sheep): support CSDataset (TODO)
Expand Down Expand Up @@ -226,7 +226,7 @@ def _generic_op(
assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation {op!r}"

validate_axis(x.ndim, axis)
return generic_op(x, op, axis=axis, keep_cupy_as_array=keep_cupy_as_array, dtype=dtype)
return generic_op(x, op, axis=axis, keep_cupy_as_array=keep_cupy_as_array, dtype=dtype) # type: ignore[no-any-return]

_generic_op.__name__ = op
return cast("StatFunNoDtype | StatFunDtype", _generic_op)
Expand Down
46 changes: 31 additions & 15 deletions src/fast_array_utils/stats/_generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,50 @@
type ComplexAxis = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None


def _run_numpy_op(
x: CpuArray | GpuArray | DiskArray | types.DaskArray,
op: Ops,
*,
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
arr = cast("NDArray[Any] | np.number[Any] | types.CupyArray | types.CupyCOOMatrix | types.DaskArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr


@singledispatch
def generic_op(
    x: CpuArray | GpuArray | DiskArray | types.DaskArray | types.HasArrayNamespace,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> Any:  # Fallback handles arbitrary array-api-compatible types, so return type can't be narrowed  # noqa: ANN401
    """Fallback: apply reduction ``op`` to ``x`` along ``axis``.

    Array-API-compatible inputs are reduced via their own namespace;
    everything else goes through the matching numpy function.
    ``keep_cupy_as_array`` only matters for the cupy dispatch and is ignored here.
    """
    del keep_cupy_as_array
    if TYPE_CHECKING:
        # these are never passed to this fallback function, but `singledispatch` wants them
        assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix)
        # np supports these, but doesn’t know it. (TODO: test cupy)
        assert not isinstance(x, types.ZarrArray | types.H5Dataset)
    # Catch array-api-compat-wrapped types that lack __array_namespace__ (i.e. PyTorch)
    import array_api_compat

    if array_api_compat.is_array_api_obj(x):
        xp = array_api_compat.array_namespace(x)
        return getattr(xp, op)(x, axis=axis, **_dtype_kw(dtype, op))

    arr = getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op))
    # reductions on a CupyCOOMatrix may stay sparse; densify for a consistent return type
    return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr


@generic_op.register(types.HasArrayNamespace)
def _generic_op_array_api(
    x: types.HasArrayNamespace,
    /,
    op: Ops,
    *,
    axis: Literal[0, 1] | None = None,
    dtype: DTypeLike | None = None,
    keep_cupy_as_array: bool = False,
) -> Any:  # noqa: ANN401
    """Dispatch ``op`` to the namespace of an array with native array API support."""
    del keep_cupy_as_array  # only meaningful for the cupy dispatch

    import array_api_compat

    namespace = array_api_compat.array_namespace(x)
    reduction = getattr(namespace, op)
    return reduction(x, axis=axis, **_dtype_kw(dtype, op))


@generic_op.register(types.CupyArray | types.CupyCSMatrix)
Expand All @@ -62,7 +78,7 @@ def _generic_op_cupy(
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> types.CupyArray | np.number[Any]:
arr = cast("types.CupyArray", _run_numpy_op(x, op, axis=axis, dtype=dtype))
arr = cast("types.CupyArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
return cast("np.number[Any]", arr.get()[()]) if not keep_cupy_as_array and axis is None else arr.squeeze()


Expand Down
17 changes: 16 additions & 1 deletion src/fast_array_utils/stats/_is_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,29 @@

from numpy.typing import NDArray

# checking if all values in an array are the same

Comment thread
amalia-k510 marked this conversation as resolved.
Outdated

@singledispatch
def is_constant_(
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
/,
*,
axis: Literal[0, 1] | None = None,
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray: # pragma: no cover
) -> Any: # noqa: ANN401

import array_api_compat

if array_api_compat.is_array_api_obj(a):
xp = array_api_compat.array_namespace(a)
match axis:
case None:
return bool((a == xp.reshape(a, (-1,))[0]).all())
case 0:
return is_constant_(a.T, axis=1) # reusing axis = 1
case 1:
b = xp.broadcast_to(a[:, 0:1], a.shape)
return (a == b).all(axis=1)
Comment thread
flying-sheep marked this conversation as resolved.
Outdated
raise NotImplementedError


Expand Down
17 changes: 15 additions & 2 deletions src/fast_array_utils/stats/_mean_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,26 @@ def mean_var_(
| tuple[np.float64, np.float64]
| tuple[types.DaskArray, types.DaskArray]
):
import array_api_compat

from . import mean

if isinstance(x, np.ndarray | types.CSBase):
float64 = np.float64
else:
import array_api_compat

if array_api_compat.is_array_api_obj(x):
xp = array_api_compat.array_namespace(x)
float64 = xp.float64
else:
float64 = np.float64
Comment thread
amalia-k510 marked this conversation as resolved.
Outdated

if axis is not None and isinstance(x, types.CSBase):
mean_, var = _sparse_mean_var(x, axis=axis)
else:
mean_ = mean(x, axis=axis, dtype=np.float64)
mean_sq = mean(power(x, 2, dtype=np.float64), axis=axis) if isinstance(x, types.DaskArray) else mean(power(x, 2), axis=axis, dtype=np.float64)
mean_ = mean(x, axis=axis, dtype=float64)
mean_sq = mean(power(x, 2, dtype=float64), axis=axis) if isinstance(x, types.DaskArray) else mean(power(x, 2), axis=axis, dtype=float64)
var = mean_sq - mean_**2
if correction: # R convention == 1 (unbiased estimator)
n = np.prod(x.shape) if axis is None else x.shape[axis]
Expand Down
13 changes: 11 additions & 2 deletions src/fast_array_utils/stats/_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@


if TYPE_CHECKING:
from typing import Any

from numpy.typing import DTypeLike

from fast_array_utils.typing import CpuArray, GpuArray
Expand All @@ -21,13 +23,20 @@
def power[Arr: Array](x: Arr, n: int, /, dtype: DTypeLike | None = None) -> Arr:
    """Take array or matrix to a power, optionally computing in ``dtype``."""
    # This wrapper is necessary because TypeVars can’t be used in `singledispatch` functions
    return _power(x, n, dtype=dtype)  # type: ignore[no-any-return]


@singledispatch
def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Any:  # noqa: ANN401
    """Fallback: exponentiate dense numpy or array-API-compatible arrays."""
    if TYPE_CHECKING:
        # these types are handled by registered dispatches, never by this fallback
        assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)

    import array_api_compat

    if array_api_compat.is_array_api_obj(x):
        xp = array_api_compat.array_namespace(x)
        # array API `pow` has no `dtype=` parameter, so cast first when one is requested
        return xp.pow(x, n) if dtype is None else xp.pow(xp.astype(x, dtype), n)

    return x**n if dtype is None else np.power(x, n, dtype=dtype)  # type: ignore[operator]


Expand Down
11 changes: 10 additions & 1 deletion src/fast_array_utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol, runtime_checkable


__all__ = [
Expand Down Expand Up @@ -37,6 +37,8 @@

# scipy sparse
if TYPE_CHECKING:
from types import ModuleType

from scipy.sparse import coo_array, coo_matrix, csc_array, csc_matrix, csr_array, csr_matrix, sparray, spmatrix
else:
try: # cs?_array isn’t available in older scipy versions
Expand Down Expand Up @@ -116,3 +118,10 @@
CSRDataset.__module__ = CSCDataset.__module__ = "anndata.abc"
CSDataset = CSRDataset | CSCDataset
"""Anndata sparse out-of-core matrices."""


@runtime_checkable
class HasArrayNamespace(Protocol):
"""An array object compatible with the Python array API standard."""

def __array_namespace__(self, /, *, api_version: str | None = None) -> ModuleType: ...
121 changes: 121 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING

import numpy as np
import pytest

from fast_array_utils import stats
from fast_array_utils.conv import to_dense


if TYPE_CHECKING:
from typing import Any, Literal

# Skip the whole module when jax is not installed.
pytestmark = pytest.mark.skipif(not find_spec("jax"), reason="jax not installed")

if find_spec("jax"):
    # Enable 64-bit precision in JAX (it defaults to 32-bit):
    # mean_var passes dtype=np.float64 internally, which crashes without this setting.
    import jax

    jax.config.update("jax_enable_x64", True)  # noqa: FBT003


@pytest.fixture
def jax_arr() -> Any:  # noqa: ANN401
    """3×2 float32 JAX array whose second column is all zeros."""
    import jax.numpy as jnp

    return jnp.asarray([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]], dtype=jnp.float32)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_sum(jax_arr: Any, axis: Literal[0, 1] | None) -> None:  # noqa: ANN401
    """`stats.sum` matches `jnp.sum` for every supported axis."""
    import jax.numpy as jnp

    assert jnp.array_equal(stats.sum(jax_arr, axis=axis), jnp.sum(jax_arr, axis=axis))


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_min(jax_arr: Any, axis: Literal[0, 1] | None) -> None:  # noqa: ANN401
    """`stats.min` matches `jnp.min` for every supported axis."""
    import jax.numpy as jnp

    assert jnp.array_equal(stats.min(jax_arr, axis=axis), jnp.min(jax_arr, axis=axis))


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_max(jax_arr: Any, axis: Literal[0, 1] | None) -> None:  # noqa: ANN401
    """`stats.max` matches `jnp.max` for every supported axis."""
    import jax.numpy as jnp

    assert jnp.array_equal(stats.max(jax_arr, axis=axis), jnp.max(jax_arr, axis=axis))


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_mean(jax_arr: Any, axis: Literal[0, 1] | None) -> None:  # noqa: ANN401
    """`stats.mean` matches `jnp.mean` for every supported axis (float compare)."""
    import jax.numpy as jnp

    assert jnp.allclose(stats.mean(jax_arr, axis=axis), jnp.mean(jax_arr, axis=axis))


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_is_constant(axis: Literal[0, 1] | None) -> None:
import jax.numpy as jnp

x = jnp.array(
[
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 0],
],
dtype=jnp.float32,
)
result = stats.is_constant(x, axis=axis)

if axis is None:
assert bool(result) is False
elif axis == 0:
expected = jnp.array([True, True, False, False])
assert jnp.array_equal(result, expected)
else:
expected = jnp.array([False, False, True, True, False, True])
assert jnp.array_equal(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_mean_var(jax_arr: Any, axis: Literal[0, 1] | None) -> None:  # noqa: ANN401
    """`stats.mean_var` with correction=1 matches jnp's biased var rescaled to unbiased."""
    import jax.numpy as jnp

    mean, var = stats.mean_var(jax_arr, axis=axis, correction=1)

    count = jax_arr.size if axis is None else jax_arr.shape[axis]
    expected_mean = jnp.mean(jax_arr, axis=axis)
    # jnp.var is the biased estimator; rescale by n/(n-1) for correction=1
    expected_var = jnp.var(jax_arr, axis=axis) * count / (count - 1)

    assert jnp.allclose(mean, expected_mean)
    assert jnp.allclose(var, expected_var)


def test_to_dense(jax_arr: Any) -> None:  # noqa: ANN401
    """Without `to_cpu_memory`, a dense JAX array passes through with equal contents."""
    import jax.numpy as jnp

    assert jnp.array_equal(to_dense(jax_arr), jax_arr)


def test_to_dense_to_cpu(jax_arr: Any) -> None:  # noqa: ANN401
    """With `to_cpu_memory=True`, the result is a host numpy array with equal contents."""
    dense = to_dense(jax_arr, to_cpu_memory=True)
    assert isinstance(dense, np.ndarray)
    np.testing.assert_array_equal(dense, np.asarray(jax_arr))
Loading
Loading