Adding array-api-compat fallback#159
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #159 +/- ##
==========================================
+ Coverage 99.29% 99.33% +0.03%
==========================================
Files 21 21
Lines 566 598 +32
==========================================
+ Hits 562 594 +32
Misses 4 4 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Merging this PR will degrade performance by 32.86%
|
| Benchmark | BASE |
HEAD |
Efficiency | |
|---|---|---|---|---|
| 👁 | test_stats_benchmark[scipy.sparse.csc_array-2d-ax0-float64-is_constant] |
2.9 ms | 4.3 ms | -32.86% |
| 👁 | test_stats_benchmark[scipy.sparse.csr_array-2d-ax1-int32-is_constant] |
2.8 ms | 3.2 ms | -10.17% |
Comparing amalia-k510:array-api-implementation (5de4e5b) with main (a54356d)
flying-sheep
left a comment
There was a problem hiding this comment.
good start. I wonder if adding a ArrayAPIObject protocol that checks for __array_namespace__ or so could be used instead of putting the array api stuff in the @singledispatch fallback body. What did we do in the AnnData PR?
for more information, see https://pre-commit.ci
…k510/fast-array-utils into array-api-implementation
There was a problem hiding this comment.
OK, before we move on, I need to understand this comment:
Catch array-api-compat-wrapped types that lack
__array_namespace__(i.e. PyTorch)
Once I do, I can form an actual opinion about how I’d like this to look (all my comments about an ABC below rely on this). I commented on the line of the comment below so we keep that in a subthread.
Co-authored-by: Philipp A. <flying-sheep@web.de>
for more information, see https://pre-commit.ci
|
OK, so basically torch doesn’t actually support array API yet (see #159 (comment)), so all the fallback code is just for torch. I don’t think torch should be part of this PR then, we should think about it separately if we want to support it. So please
Regarding the benchmarks: I’m not so sure if this is just static overhead … I think your |
…k510/fast-array-utils into array-api-implementation
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
…k510/fast-array-utils into array-api-implementation
for more information, see https://pre-commit.ci
|
I am still running into mypy issues. It's failing on |
…k510/fast-array-utils into array-api-implementation
for more information, see https://pre-commit.ci
…k510/fast-array-utils into array-api-implementation
|
OK, I started with the typing. As you can see, adding an overload to |
Thanks for the example! I'd actually like to take that on myself if that's okay. |
This PR adds
array-api-compatas a fallback in thesingledispatchfunctions across the stats and conv modules so that Array API-compatible arrays (JAX, PyTorch, and others) work out of the box without needing to register each backend individually. The approach is: in each fallback, tryarray_api_compat.array_namespace(x)first. If it recognizes the array, dispatch through the standard Array API; if not, fall through to the existing numpy path. This touches_generic_ops.py(sum/min/max),_is_constant.py,_power.py, andconv/_to_dense.py.array-api-compatis added as a dependency inpyproject.toml.Tests are in
tests/test_jax.pycovering all the stats functions andto_densewith JAX arrays. One thing to flag:_mean_var.pydidn't need changes since it callsmean()andpower()internally which already go through the fixed dispatchers, at least in my understanding, but JAX requiresjax_enable_x64for thedtype=np.float64calls to work.