-
Notifications
You must be signed in to change notification settings - Fork 3
Adding array-api-compat fallback #159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
flying-sheep
merged 43 commits into
scverse:main
from
amalia-k510:array-api-implementation
Apr 30, 2026
Merged
Changes from 25 commits
Commits
Show all changes
43 commits
Select commit
Hold shift + click to select a range
64f5304
array-api initially implementation
amalia-k510 5373050
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f016c39
updates in regards to the handler and some array_api handling fixes
amalia-k510 cca3ad6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c1bb155
Issues with jax test are fixed, introduced similar tests with pytorch
amalia-k510 c9a8f85
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 bd08c2e
pre-commit fixes
amalia-k510 b2e3f9b
mipy issues fix
amalia-k510 65e83a6
speed up fix
amalia-k510 2a1924d
Update src/fast_array_utils/stats/_mean_var.py
amalia-k510 9684213
Addressed the comments
amalia-k510 1e3c296
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 63c5e16
chore: simplify
flying-sheep 9c8466a
addressing comments about removing is_array_api_obj check
amalia-k510 86a7503
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 e28f176
import fix
amalia-k510 c704259
ignore comments update and mypy test
amalia-k510 ab6e200
Merge branch 'main' into array-api-implementation
amalia-k510 aaccbda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7e5102a
main version
amalia-k510 eed16a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 567107a
residues ignore comments removed
amalia-k510 1e41d24
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 aab50f7
pyproject, jax optional dependencies
amalia-k510 91ef896
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 0f62abc
commented addressed, mypy try again
amalia-k510 d642668
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 37a6634
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] bab4392
mypy comment add
amalia-k510 d22f74b
Merge branch 'array-api-implementation' of https://github.com/amalia-…
amalia-k510 3d4ee3a
types
flying-sheep c77a1bc
types for others
amalia-k510 6d3891e
types, missing parameters
amalia-k510 9e2a9fe
revert pyproject.toml
amalia-k510 57117d2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 0415fd5
rework deps
flying-sheep 3857917
Merge branch 'main' into array-api-implementation
flying-sheep 592b7b7
fix deps
flying-sheep c890bc3
fix types
flying-sheep 49283b0
fix cupy tests
flying-sheep f6463cc
fix disk array
flying-sheep 474c969
fmt
flying-sheep 5de4e5b
coverage
flying-sheep File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| # SPDX-License-Identifier: MPL-2.0 | ||
| from __future__ import annotations | ||
|
|
||
| from importlib.util import find_spec | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from fast_array_utils import stats | ||
| from fast_array_utils.conv import to_dense | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| from typing import Any, Literal | ||
|
|
||
| pytestmark = pytest.mark.skipif(not find_spec("jax"), reason="jax not installed") | ||
|
|
||
| if find_spec("jax"): | ||
| # enabling 64-bit precision in JAX as it defaults to 32-bit only | ||
| # problem as mean_var passes dtype= np.float64 internally, which crashes without this fix | ||
| import jax | ||
|
|
||
| jax.config.update("jax_enable_x64", True) # noqa: FBT003 | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def jax_arr() -> Any: # noqa: ANN401 | ||
| import jax.numpy as jnp | ||
|
|
||
| return jnp.array([[1, 0], [2, 0], [3, 0]], dtype=jnp.float32) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("axis", [None, 0, 1]) | ||
| def test_sum(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401 | ||
| import jax.numpy as jnp | ||
|
|
||
| result = stats.sum(jax_arr, axis=axis) | ||
| expected = jnp.sum(jax_arr, axis=axis) | ||
| assert jnp.array_equal(result, expected) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("axis", [None, 0, 1]) | ||
| def test_min(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401 | ||
| import jax.numpy as jnp | ||
|
|
||
| result = stats.min(jax_arr, axis=axis) | ||
| expected = jnp.min(jax_arr, axis=axis) | ||
| assert jnp.array_equal(result, expected) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("axis", [None, 0, 1]) | ||
| def test_max(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401 | ||
| import jax.numpy as jnp | ||
|
|
||
| result = stats.max(jax_arr, axis=axis) | ||
| expected = jnp.max(jax_arr, axis=axis) | ||
| assert jnp.array_equal(result, expected) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("axis", [None, 0, 1]) | ||
| def test_mean(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401 | ||
| import jax.numpy as jnp | ||
|
|
||
| result = stats.mean(jax_arr, axis=axis) | ||
| expected = jnp.mean(jax_arr, axis=axis) | ||
| assert jnp.allclose(result, expected) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("axis", [None, 0, 1]) | ||
| def test_is_constant(axis: Literal[0, 1] | None) -> None: | ||
| import jax.numpy as jnp | ||
|
|
||
| x = jnp.array( | ||
| [ | ||
| [0, 0, 1, 1], | ||
| [0, 0, 1, 1], | ||
| [0, 0, 0, 0], | ||
| [0, 0, 0, 0], | ||
| [0, 0, 1, 0], | ||
| [0, 0, 0, 0], | ||
| ], | ||
| dtype=jnp.float32, | ||
| ) | ||
| result = stats.is_constant(x, axis=axis) | ||
|
|
||
| if axis is None: | ||
| assert bool(result) is False | ||
| elif axis == 0: | ||
| expected = jnp.array([True, True, False, False]) | ||
| assert jnp.array_equal(result, expected) | ||
| else: | ||
| expected = jnp.array([False, False, True, True, False, True]) | ||
| assert jnp.array_equal(result, expected) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("axis", [None, 0, 1]) | ||
| def test_mean_var(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401 | ||
| import jax.numpy as jnp | ||
|
|
||
| mean, var = stats.mean_var(jax_arr, axis=axis, correction=1) | ||
|
|
||
| mean_expected = jnp.mean(jax_arr, axis=axis) | ||
| n = jax_arr.size if axis is None else jax_arr.shape[axis] | ||
| var_expected = jnp.var(jax_arr, axis=axis) * n / (n - 1) | ||
|
|
||
| assert jnp.allclose(mean, mean_expected) | ||
| assert jnp.allclose(var, var_expected) | ||
|
|
||
|
|
||
| def test_to_dense(jax_arr: Any) -> None: # noqa: ANN401 | ||
| import jax.numpy as jnp | ||
|
|
||
| result = to_dense(jax_arr) | ||
| assert jnp.array_equal(result, jax_arr) | ||
|
|
||
|
|
||
| def test_to_dense_to_cpu(jax_arr: Any) -> None: # noqa: ANN401 | ||
| result = to_dense(jax_arr, to_cpu_memory=True) | ||
| assert isinstance(result, np.ndarray) | ||
| np.testing.assert_array_equal(result, np.asarray(jax_arr)) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.