Re-enable JAX 0.10.0 with MoE layout-constraint fix#2683
Merged
Conversation
DescriptionStart with a short description of what the PR does and how this is a change from The rest of the description includes relevant details and context, examples:
If the change fixes a Github issue, please include a link, e.g.,: TestsPlease describe how you tested this change, and include any instructions and/or ChecklistBefore submitting this PR, please make sure:
|
Collaborator
|
@guowei-dev: fyi, after this pr has been merged, the reverted sparse core kernel pr can be merged again: #2634 |
lk-chen
approved these changes
May 20, 2026
3 tasks
This reverts commit 60cd862. Signed-off-by: Qiliang Cui <derrhein@gmail.com>
…ax 0.10 The with_layout_constraint(weight, Layout(...)) calls in process_moe_weights were a workaround for a jax-0.9.x "must have valid byte strides" error. On jax 0.10 the propagation path now reads layout_constraint_p in get_out_layouts_via_propagation and actually pins the requested layout to the jit output — which is the WRONG layout for the MoE kernel's weight reads and silently miscomputes: gemma-4-26B-A4B-it GSM8K accuracy drops from 0.41 to 0.0, and individual prompts return garbage. The MoE kernel can use the natural C-order layout that the auto-layout pass picks; no constraint is needed on jax 0.10. Removing the three call sites (unquantized weight processing + fused-moe weight reshape) fixes gemma-4 accuracy without reverting jax. All other MoE quantization paths route through the same process_moe_weights, so they all benefit from the fix. Verified: - Before: gemma-4-26B-A4B-it returns garbage tokens on jax 0.10.0. - After: same model returns "4" (2+2), "56" (7x8), coherent text on a haiku prompt. Signed-off-by: Qiliang Cui <derrhein@gmail.com>
0f619cb to
a801dde
Compare
ylangtsou
pushed a commit
to ylangtsou/tpu-inference
that referenced
this pull request
May 21, 2026
Signed-off-by: Qiliang Cui <derrhein@gmail.com> Co-authored-by: Qiliang Cui <derrhein@gmail.com>
jyj0w0
pushed a commit
that referenced
this pull request
May 25, 2026
Signed-off-by: Qiliang Cui <derrhein@gmail.com> Co-authored-by: Qiliang Cui <derrhein@gmail.com>
patrickji2014
pushed a commit
that referenced
this pull request
May 27, 2026
Signed-off-by: Qiliang Cui <derrhein@gmail.com> Co-authored-by: Qiliang Cui <derrhein@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Reapplies the JAX 0.10.0 upgrade (originally #2291) that was reverted in #2648, and bundles the MoE layout-constraint fix from #2665 so the upgrade is safe.
Why this is two commits
Reapply "Update JAX to 0.10.0" (#2648): reverts the revert, restoring all JAX 0.10.0 changes (requirements.txt, sparse-core kernels, related tests).[MoE] Remove with_layout_constraint that silently miscomputes under jax 0.10: the actual fix.The bug that motivated the original revert
After #2291 landed, gemma-4-26B-A4B-it started returning garbage tokens on JAX 0.10.0 (GSM8K accuracy 0.41 → 0.0). #2648 reverted the JAX upgrade to recover accuracy. The root cause was localized to MoE weight processing — not JAX itself.
Root cause
The
with_layout_constraint(weight, Layout(...))calls inprocess_moe_weightswere a workaround for a jax-0.9.xmust have valid byte strideserror. On jax 0.10 the propagation path now readslayout_constraint_pinget_out_layouts_via_propagationand actually pins the requested layout to the jit output — which is the wrong layout for the MoE kernel's weight reads and silently miscomputes.Concretely:
Fix
Remove three
with_layout_constraintcall sites intpu_inference/layers/common/process_weights/moe_weights.py(unquantized weight processing + fused-moe weight reshape) and drop the now-unused import. The MoE kernel can use the natural C-order layout the auto-layout pass picks; no constraint is needed on jax 0.10. All other MoE quantization paths route through the sameprocess_moe_weights, so they all benefit.Verification (from #2665)
4(2+2),56(7x8), coherent text on a haiku prompt.Perf comparison on the v0.21.0 nightly (builds 17833 baseline vs 18033 with the fix):
i.e. on jax 0.10, auto-layout picks an equally-good (or better) layout for the MoE GMM kernel; the manual constraint is no longer earning its keep.
Related
Test plan