Adding array-api-compat fallback #159

Merged
flying-sheep merged 43 commits into scverse:main from amalia-k510:array-api-implementation
Apr 30, 2026

@amalia-k510
Contributor

This PR adds array-api-compat as a fallback in the singledispatch functions across the stats and conv modules so that Array API-compatible arrays (JAX, PyTorch, and others) work out of the box without needing to register each backend individually. The approach is: in each fallback, try array_api_compat.array_namespace(x) first. If it recognizes the array, dispatch through the standard Array API; if not, fall through to the existing numpy path. This touches _generic_ops.py (sum/min/max), _is_constant.py, _power.py, and conv/_to_dense.py. array-api-compat is added as a dependency in pyproject.toml.
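
The fallback pattern described above could be sketched roughly as follows. This is a stdlib-only illustration, not the PR's code: `ToyArray`, the local `array_namespace` helper (a stand-in for `array_api_compat.array_namespace`, assumed here to raise `TypeError` for unrecognized inputs), and `sum_` are hypothetical names, and the plain `sum()` call stands in for the existing numpy path.

```python
from functools import singledispatch
from types import SimpleNamespace


class ToyArray:
    """Hypothetical stand-in for an Array API-compatible array (e.g. a JAX array)."""

    def __init__(self, data):
        self.data = list(data)

    def __array_namespace__(self, *, api_version=None):
        # A real library returns its Array API module; this toy one only has `sum`.
        return SimpleNamespace(sum=lambda x: sum(x.data))


def array_namespace(x):
    """Stand-in for array_api_compat.array_namespace (assumed to raise TypeError)."""
    get_namespace = getattr(x, "__array_namespace__", None)
    if get_namespace is None:
        raise TypeError(f"{type(x).__name__} is not a supported array type")
    return get_namespace()


@singledispatch
def sum_(x):
    # Fallback body: try the Array API first, otherwise take the pre-existing path.
    try:
        xp = array_namespace(x)
    except TypeError:
        return sum(x)  # placeholder for the existing numpy path
    return xp.sum(x)


print(sum_(ToyArray([1, 2, 3])))  # dispatched through the toy Array API namespace
print(sum_([4, 5]))  # not recognized, so the placeholder path runs
```

Backends with a dedicated `@sum_.register(...)` implementation would never reach this body; only unregistered types do.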

Tests are in tests/test_jax.py, covering all the stats functions and to_dense with JAX arrays. One thing to flag: as far as I understand, _mean_var.py didn't need changes, since it calls mean() and power() internally and those already go through the fixed dispatchers. However, JAX requires jax_enable_x64 for the dtype=np.float64 calls to work.

@codecov

codecov Bot commented Mar 23, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 99.33%. Comparing base (a54356d) to head (5de4e5b).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #159      +/-   ##
==========================================
+ Coverage   99.29%   99.33%   +0.03%     
==========================================
  Files          21       21              
  Lines         566      598      +32     
==========================================
+ Hits          562      594      +32     
  Misses          4        4              


@codspeed-hq

codspeed-hq Bot commented Mar 23, 2026

Merging this PR will degrade performance by 32.86%

⚠️ Different runtime environments detected

Some benchmarks with significant performance changes were compared across different runtime environments,
which may affect the accuracy of the results.

❌ 2 (👁 2) regressed benchmarks
✅ 230 untouched benchmarks

Performance Changes

| Benchmark | BASE | HEAD | Efficiency |
|---|---|---|---|
| 👁 test_stats_benchmark[scipy.sparse.csc_array-2d-ax0-float64-is_constant] | 2.9 ms | 4.3 ms | -32.86% |
| 👁 test_stats_benchmark[scipy.sparse.csr_array-2d-ax1-int32-is_constant] | 2.8 ms | 3.2 ms | -10.17% |

Comparing amalia-k510:array-api-implementation (5de4e5b) with main (a54356d)

Member

@flying-sheep flying-sheep left a comment

Good start. I wonder if an ArrayAPIObject protocol that checks for __array_namespace__ or so could be used instead of putting the array API stuff in the @singledispatch fallback body. What did we do in the AnnData PR?
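
One way such a protocol could look, as a sketch only: a runtime-checkable `Protocol` is registered on the `singledispatch` function, so Array API objects get their own branch and the fallback body stays untouched. `HasArrayNamespace` borrows the name used later in this thread; `ToyArray` and `max_` are hypothetical, and the toy namespace implements only `max`.

```python
from functools import singledispatch
from types import SimpleNamespace
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasArrayNamespace(Protocol):
    """Structural type: any object exposing `__array_namespace__`."""

    def __array_namespace__(self, *, api_version: Any = None) -> Any: ...


class ToyArray:
    """Hypothetical array type; it does not inherit from the protocol."""

    def __init__(self, data):
        self.data = list(data)

    def __array_namespace__(self, *, api_version=None):
        return SimpleNamespace(max=lambda x: max(x.data))


@singledispatch
def max_(x):
    # Unchanged fallback, standing in for the existing numpy path.
    return ("fallback", max(x))


@max_.register(HasArrayNamespace)
def _max_array_api(x):
    # Chosen for any class that structurally satisfies the protocol.
    xp = x.__array_namespace__()
    return ("array-api", xp.max(x))


print(max_(ToyArray([3, 1, 2])))  # ('array-api', 3)
print(max_([3, 1, 2]))  # ('fallback', 3)
```

Dispatch here relies on `issubclass` checks against the protocol, which only work for protocols whose members are all methods, as is the case for `__array_namespace__`.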

Comment thread src/fast_array_utils/conv/_to_dense.py Outdated
Comment thread src/fast_array_utils/stats/_generic_ops.py Outdated
Comment thread src/fast_array_utils/stats/_generic_ops.py Outdated
Comment thread 03_23_2026.log Outdated
Member

@flying-sheep flying-sheep left a comment

OK, before we move on, I need to understand this comment:

> Catch array-api-compat-wrapped types that lack __array_namespace__ (i.e. PyTorch)

Once I do, I can form an actual opinion about how I’d like this to look (all my comments about an ABC below rely on this). I commented on the line of the comment below so we keep that in a subthread.

Comment thread src/fast_array_utils/stats/_generic_ops.py Outdated
Comment thread src/fast_array_utils/stats/_generic_ops.py Outdated
Comment thread pyproject.toml Outdated
Comment thread src/fast_array_utils/stats/_is_constant.py Outdated
Comment thread src/fast_array_utils/stats/_is_constant.py Outdated
Comment thread src/fast_array_utils/stats/_mean_var.py Outdated
@flying-sheep
Member

flying-sheep commented Apr 20, 2026

OK, so basically torch doesn’t actually support array API yet (see #159 (comment)), so all the fallback code is just for torch. I don’t think torch should be part of this PR then, we should think about it separately if we want to support it.

So please

  1. move the torch parts (i.e. the fallback code and the torch tests) out of this PR
  2. wherever there is only an if array_api_compat.is_array_api_obj branch and no @register(... | HasArrayNamespace) branch, create that registration
  3. revert the Any changes. I can fix the types, but essentially unions will gain a | HasArrayNamespace member and all touched functions gain an overload like def func[A: HasArrayNamespace](x: A, ...) -> A: ...
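
The overload shape from point 3 could be sketched like this. It uses a `TypeVar` bound instead of the `def func[A: HasArrayNamespace]` syntax so that it also runs on Python versions before 3.12; `ToyArray`, the toy `pow` namespace, and this `power` function are hypothetical stand-ins, not the library's code.

```python
from types import SimpleNamespace
from typing import Any, Protocol, TypeVar, overload, runtime_checkable


@runtime_checkable
class HasArrayNamespace(Protocol):
    def __array_namespace__(self, *, api_version: Any = None) -> Any: ...


class ToyArray:
    """Hypothetical Array API-style array whose namespace only offers `pow`."""

    def __init__(self, data):
        self.data = list(data)

    def __array_namespace__(self, *, api_version=None):
        return SimpleNamespace(pow=lambda x, n: ToyArray(v**n for v in x.data))


# Pre-3.12 spelling of the `[A: HasArrayNamespace]` type parameter.
A = TypeVar("A", bound=HasArrayNamespace)


@overload
def power(x: A, n: int) -> A: ...
@overload
def power(x: float, n: int) -> float: ...
def power(x: Any, n: int) -> Any:
    # The overloads tell the type checker that the input array type is preserved.
    if isinstance(x, HasArrayNamespace):
        return x.__array_namespace__().pow(x, n)
    return x**n


squared = power(ToyArray([1, 2, 3]), 2)
print(type(squared).__name__, squared.data)  # ToyArray [1, 4, 9]
print(power(3.0, 2))  # 9.0
```

With the overload in place, callers that pass an Array API object get the same type back without resorting to `Any`.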

Regarding the benchmarks: I’m not so sure if this is just static overhead … I think your xp.pow(xp.astype(x, dtype), n) is just genuinely slower than the branch numpy code took before. Can you improve that?

@amalia-k510
Contributor Author

amalia-k510 commented Apr 27, 2026

I am still running into mypy issues. It's failing on import jax since jax isn't in the mypy environment. Also, from what I understand, I can't add a mypy override either since the config was removed from main. What would be the best way to handle it? @flying-sheep

Member

@flying-sheep flying-sheep left a comment

awesome, this is looking very close to perfect! I’ll take a look at the type issues

> I can't add a mypy override either since the config was removed from main

what do you mean? the mypy dependencies are directly in .pre-commit-config.yaml (sadly)

Comment thread src/fast_array_utils/stats/_generic_ops.py Outdated
Comment thread src/fast_array_utils/stats/_power.py Outdated
Comment thread pyproject.toml Outdated
@flying-sheep
Member

OK, I started with the typing. As you can see, adding an overload to sum makes sum work where it’s called. This should be done everywhere. I can do it, as you wish.

@amalia-k510
Contributor Author

> OK, I started with the typing. As you can see, adding an overload to sum makes sum work where it’s called. This should be done everywhere. I can do it, as you wish.

Thanks for the example! I'd actually like to take that on myself if that's okay.

@flying-sheep flying-sheep added the run-gpu-ci label ("Apply this label to run GPU CI once") Apr 30, 2026
@flying-sheep flying-sheep marked this pull request as ready for review April 30, 2026 10:03
@flying-sheep flying-sheep enabled auto-merge (squash) April 30, 2026 15:21
@flying-sheep flying-sheep merged commit febaf24 into scverse:main Apr 30, 2026
19 checks passed