Skip to content

Fix xla_pmap_p import for JAX versions that removed pmap#2173

Open
saitcakmak wants to merge 1 commit intopyro-ppl:masterfrom
saitcakmak:fix-xla-pmap-p-import
Open

Fix xla_pmap_p import for JAX versions that removed pmap#2173
saitcakmak wants to merge 1 commit intopyro-ppl:masterfrom
saitcakmak:fix-xla-pmap-p-import

Conversation

@saitcakmak
Copy link
Copy Markdown

Summary

  • JAX 0.10.0 removed the C++ pmap infrastructure (including xla_pmap_p). This causes an ImportError when importing numpyro with newer JAX versions.
  • Guard the xla_pmap_p import with a try/except and skip the provenance tracking rule registration when it's unavailable.

Test plan

  • python -c "from numpyro.ops.provenance import eval_provenance" succeeds
  • eval_provenance works end-to-end (tested with lambda x, y, z: x + y)
  • All 10 tests in test/ops/test_provenance.py pass

JAX removed the C++ pmap infrastructure (including xla_pmap_p) in a
recent release. Guard the import so numpyro works with both old and
new JAX versions.
@Qazalbash
Copy link
Copy Markdown
Collaborator

Thanks @saitcakmak

@Qazalbash Qazalbash requested a review from juanitorduz April 16, 2026 14:23
@juanitorduz
Copy link
Copy Markdown
Collaborator

juanitorduz commented Apr 16, 2026

LGTM.

It seems the failing tests in Python 3.14 CI (test-inference) are not caused by this PR. They fail inside funsor, a transitive dependency:

  .venv/lib/python3.14/site-packages/funsor/jax/ops.py:203                                                                                                                                                           
  E       TypeError: clip() got an unexpected keyword argument 'a_max' 

funsor calls jnp.clip(..., a_max=...), but recent JAX releases dropped the deprecated a_max/a_min kwargs (replaced by max/min, aligning with NumPy 2.x).

I am looking into the upstream issue: pyro-ppl/funsor#611 @fehiepsi would you mind taking a look?

Qazalbash added a commit to kokabsc/gwkokab that referenced this pull request Apr 18, 2026
sethaxen added a commit to sethaxen/CAGPJax that referenced this pull request Apr 20, 2026
* chore: Temporarily upperbound jax

Until numpyro v0.20 is compatible with jax v0.10, pyro-ppl/numpyro#2173

* fix: Make support a property

Change made in numpyro v0.20.0

* fix: Explicitly import jax.test_util
conorheins added a commit to infer-actively/pymdp that referenced this pull request Apr 21, 2026
setup.cfg is retained for install paths (e.g. Jetson) that don't read
pyproject.toml, so the upper bound has to live in both places to stay
consistent. Same rationale as the pyproject.toml pin — tracks
pyro-ppl/numpyro#2173.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
tedvr pushed a commit to tedvr/pymdp that referenced this pull request Apr 24, 2026
jax 0.10.0 removed xla_pmap_p from jax.extend.core.primitives, which
breaks numpyro's provenance module and thus any pybefit-using test.
Upstream fix is pyro-ppl/numpyro#2173 (approved, not yet released).
Revert this pin once a patched numpyro ships.

Fixes infer-actively#389.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

JAX 0.10.0 breaks NumPyro

4 participants