Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
64f5304
array-api initially implementation
amalia-k510 Mar 23, 2026
5373050
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2026
f016c39
updates in regards to the handler and some array_api handling fixes
amalia-k510 Apr 15, 2026
cca3ad6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2026
c1bb155
Issues with jax test are fixed, introduced similar tests with pytorch
amalia-k510 Apr 15, 2026
c9a8f85
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 15, 2026
bd08c2e
pre-commit fixes
amalia-k510 Apr 15, 2026
b2e3f9b
mipy issues fix
amalia-k510 Apr 15, 2026
65e83a6
speed up fix
amalia-k510 Apr 15, 2026
2a1924d
Update src/fast_array_utils/stats/_mean_var.py
amalia-k510 Apr 17, 2026
9684213
Addressed the comments
amalia-k510 Apr 17, 2026
1e3c296
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2026
63c5e16
chore: simplify
flying-sheep Apr 20, 2026
9c8466a
addressing comments about removing is_array_api_obj check
amalia-k510 Apr 27, 2026
86a7503
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
e28f176
import fix
amalia-k510 Apr 27, 2026
c704259
ignore comments update and mypy test
amalia-k510 Apr 27, 2026
ab6e200
Merge branch 'main' into array-api-implementation
amalia-k510 Apr 27, 2026
aaccbda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
7e5102a
main version
amalia-k510 Apr 27, 2026
eed16a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
567107a
residues ignore comments removed
amalia-k510 Apr 27, 2026
1e41d24
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
aab50f7
pyproject, jax optional dependencies
amalia-k510 Apr 27, 2026
91ef896
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
0f62abc
commented addressed, mypy try again
amalia-k510 Apr 27, 2026
d642668
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
37a6634
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2026
bab4392
mypy comment add
amalia-k510 Apr 27, 2026
d22f74b
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 Apr 27, 2026
3d4ee3a
types
flying-sheep Apr 27, 2026
c77a1bc
types for others
amalia-k510 Apr 29, 2026
6d3891e
types, missing parameters
amalia-k510 Apr 29, 2026
9e2a9fe
revert pyproject.toml
amalia-k510 Apr 29, 2026
57117d2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
0415fd5
rework deps
flying-sheep Apr 30, 2026
3857917
Merge branch 'main' into array-api-implementation
flying-sheep Apr 30, 2026
592b7b7
fix deps
flying-sheep Apr 30, 2026
c890bc3
fix types
flying-sheep Apr 30, 2026
49283b0
fix cupy tests
flying-sheep Apr 30, 2026
f6463cc
fix disk array
flying-sheep Apr 30, 2026
474c969
fmt
flying-sheep Apr 30, 2026
5de4e5b
coverage
flying-sheep Apr 30, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,361 changes: 3,361 additions & 0 deletions 03_23_2026.log
Comment thread
flying-sheep marked this conversation as resolved.
Outdated

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ classifiers = [
"Programming Language :: Python :: 3.14",
]
dynamic = [ "description", "readme", "version" ]
dependencies = [ "numpy>=2" ]
dependencies = [ "array-api-compat", "numpy>=2" ]
optional-dependencies.accel = [ "numba>=0.57" ]
optional-dependencies.dask = [ "dask>=2023.6.1" ]
optional-dependencies.full = [ "fast-array-utils[accel,dask,sparse]", "h5py", "zarr" ]
optional-dependencies.jax = [ "jax", "jaxlib" ]
optional-dependencies.sparse = [ "scipy>=1.13" ]
optional-dependencies.testing = [ "packaging" ]
urls."Issue Tracker" = "https://github.com/scverse/fast-array-utils/issues"
Expand All @@ -48,6 +49,15 @@ doc = [
"sphinx>=9.0.1",
"sphinx-autofixture>=0.4.1",
]
# for update-mypy-hook
mypy = [
"fast-array-utils[full]",
"scipy-stubs",
# TODO: replace sphinx with this: { include-group = "doc" },
"sphinx",
"types-docutils",
{ include-group = "test" },
]
Comment thread
amalia-k510 marked this conversation as resolved.
Outdated
test-min = [
"coverage[toml]",
"fast-array-utils[sparse,testing]", # include sparse for testing numba-less to_dense
Expand Down
7 changes: 7 additions & 0 deletions src/fast_array_utils/conv/_to_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def to_dense_(
order: Literal["K", "A", "C", "F"] = "K",
to_cpu_memory: bool = False,
) -> NDArray[Any] | types.CupyArray | types.DaskArray:
import array_api_compat

if not isinstance(x, np.ndarray) and array_api_compat.is_array_api_obj(x):
if to_cpu_memory:
return np.asarray(x, order=order)
return x # already dense
Comment thread
flying-sheep marked this conversation as resolved.
Outdated

del to_cpu_memory # it already is
return np.asarray(x, order=order)

Expand Down
9 changes: 8 additions & 1 deletion src/fast_array_utils/stats/_generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,20 @@ def generic_op(
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: # switch to Any later
Comment thread
flying-sheep marked this conversation as resolved.
Outdated
del keep_cupy_as_array
if TYPE_CHECKING:
# these are never passed to this fallback function, but `singledispatch` wants them
assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix)
# np supports these, but doesn’t know it. (TODO: test cupy)
assert not isinstance(x, types.ZarrArray | types.H5Dataset)

# doing array_api_compat first
import array_api_compat

if array_api_compat.is_array_api_obj(x):
xp = array_api_compat.array_namespace(x)
return getattr(xp, op)(x, axis=axis, **_dtype_kw(dtype, op))
Comment thread
flying-sheep marked this conversation as resolved.
Outdated
return cast("NDArray[Any] | np.number[Any]", _run_numpy_op(x, op, axis=axis, dtype=dtype))


Expand Down
17 changes: 16 additions & 1 deletion src/fast_array_utils/stats/_is_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,29 @@

from numpy.typing import NDArray

# checking if all values in an array are the same

Comment thread
amalia-k510 marked this conversation as resolved.
Outdated

@singledispatch
def is_constant_(
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
/,
*,
axis: Literal[0, 1] | None = None,
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray: # pragma: no cover
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray: # switch to Any later

import array_api_compat

if array_api_compat.is_array_api_obj(a):
xp = array_api_compat.array_namespace(a)
match axis:
case None:
return bool((a == xp.reshape(a, (-1,))[0]).all())
case 0:
return is_constant_(a.T, axis=1) # reusing axis = 1
case 1:
b = xp.broadcast_to(a[:, 0:1], a.shape)
return (a == b).all(axis=1)
Comment thread
flying-sheep marked this conversation as resolved.
Outdated
raise NotImplementedError


Expand Down
7 changes: 7 additions & 0 deletions src/fast_array_utils/stats/_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def power[Arr: Array](x: Arr, n: int, /, dtype: DTypeLike | None = None) -> Arr:
def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
if TYPE_CHECKING:
assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)

import array_api_compat

if array_api_compat.is_array_api_obj(x):
xp = array_api_compat.array_namespace(x)
return xp.pow(x, n) if dtype is None else xp.pow(x.astype(dtype), n)

return x**n if dtype is None else np.power(x, n, dtype=dtype) # type: ignore[operator]


Expand Down
121 changes: 121 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING

import numpy as np
import pytest

from fast_array_utils import stats
from fast_array_utils.conv import to_dense


if TYPE_CHECKING:
from typing import Any, Literal

pytestmark = pytest.mark.skipif(not find_spec("jax"), reason="jax not installed")

if find_spec("jax"):
# enabling 64-bit precision in JAX as it defaults to 32-bit only
# problem as mean_var passes dtype= np.float64 internally, which crashes without this fix
import jax

jax.config.update("jax_enable_x64", True)


@pytest.fixture
def jax_arr() -> Any:
import jax.numpy as jnp

return jnp.array([[1, 0], [2, 0], [3, 0]], dtype=jnp.float32)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_sum(jax_arr: Any, axis: Literal[0, 1] | None) -> None:
import jax.numpy as jnp

result = stats.sum(jax_arr, axis=axis)
expected = jnp.sum(jax_arr, axis=axis)
assert jnp.array_equal(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_min(jax_arr: Any, axis: Literal[0, 1] | None) -> None:
import jax.numpy as jnp

result = stats.min(jax_arr, axis=axis)
expected = jnp.min(jax_arr, axis=axis)
assert jnp.array_equal(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_max(jax_arr: Any, axis: Literal[0, 1] | None) -> None:
import jax.numpy as jnp

result = stats.max(jax_arr, axis=axis)
expected = jnp.max(jax_arr, axis=axis)
assert jnp.array_equal(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_mean(jax_arr: Any, axis: Literal[0, 1] | None) -> None:
import jax.numpy as jnp

result = stats.mean(jax_arr, axis=axis)
expected = jnp.mean(jax_arr, axis=axis)
assert jnp.allclose(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_is_constant(axis: Literal[0, 1] | None) -> None:
import jax.numpy as jnp

x = jnp.array(
[
[0, 0, 1, 1],
[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 0],
],
dtype=jnp.float32,
)
result = stats.is_constant(x, axis=axis)

if axis is None:
assert bool(result) is False
elif axis == 0:
expected = jnp.array([True, True, False, False])
assert jnp.array_equal(result, expected)
else:
expected = jnp.array([False, False, True, True, False, True])
assert jnp.array_equal(result, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_mean_var(jax_arr: Any, axis: Literal[0, 1] | None) -> None:
import jax.numpy as jnp

mean, var = stats.mean_var(jax_arr, axis=axis, correction=1)

mean_expected = jnp.mean(jax_arr, axis=axis)
n = jax_arr.size if axis is None else jax_arr.shape[axis]
var_expected = jnp.var(jax_arr, axis=axis) * n / (n - 1)

assert jnp.allclose(mean, mean_expected)
assert jnp.allclose(var, var_expected)


def test_to_dense(jax_arr: Any) -> None:
import jax.numpy as jnp

result = to_dense(jax_arr)
assert jnp.array_equal(result, jax_arr)


def test_to_dense_to_cpu(jax_arr: Any) -> None:
result = to_dense(jax_arr, to_cpu_memory=True)
assert isinstance(result, np.ndarray)
np.testing.assert_array_equal(result, np.asarray(jax_arr))
Loading