Skip to content

[jax-0.9.1.dev] migrate deprecated Pallas to modern indexing#9

Open
judyzhaoxinwu wants to merge 38 commits into
Borklet-Labs:jax_0.9.1_devfrom
judyzhaoxinwu:jw/fix-pallas-nonetype
Open

[jax-0.9.1.dev] migrate deprecated Pallas to modern indexing#9
judyzhaoxinwu wants to merge 38 commits into
Borklet-Labs:jax_0.9.1_devfrom
judyzhaoxinwu:jw/fix-pallas-nonetype

Conversation

@judyzhaoxinwu
Copy link
Copy Markdown

Description Suggestion

This PR aligns AXLearn with JAX 0.9.1.dev by migrating legacy Pallas API calls to the modern indexing syntax.

Core Changes
Pallas Modernization: Replaced deprecated pl.load, pl.store, and pl.swap with direct reference indexing (e.g., x = x_ref[...], o_ref[...] = y).

andersensam and others added 30 commits January 30, 2026 21:19
…guments before tracing. Corrects get_kernel_name signature mismatch in tpu_splash_attention.py by using keyword arguments for metadata.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants