Skip to content

Re-enable JAX 0.10.0 with MoE layout-constraint fix#2683

Merged
QiliangCui merged 2 commits into
mainfrom
cuiq-jax010-with-moe-fix
May 21, 2026
Merged

Re-enable JAX 0.10.0 with MoE layout-constraint fix#2683
QiliangCui merged 2 commits into
mainfrom
cuiq-jax010-with-moe-fix

Conversation

@QiliangCui2023
Copy link
Copy Markdown
Collaborator

Summary

Reapplies the JAX 0.10.0 upgrade (originally #2291) that was reverted in #2648, and bundles the MoE layout-constraint fix from #2665 so the upgrade is safe.

Why this is two commits

  • Commit 1 — Reapply "Update JAX to 0.10.0" (#2648): reverts the revert, restoring all JAX 0.10.0 changes (requirements.txt, sparse-core kernels, related tests).
  • Commit 2 — [MoE] Remove with_layout_constraint that silently miscomputes under jax 0.10: the actual fix.

The bug that motivated the original revert

After #2291 landed, gemma-4-26B-A4B-it started returning garbage tokens on JAX 0.10.0 (GSM8K accuracy 0.41 → 0.0). #2648 reverted the JAX upgrade to recover accuracy. The root cause was localized to MoE weight processing — not JAX itself.

Root cause

The with_layout_constraint(weight, Layout(...)) calls in process_moe_weights were a workaround for a jax-0.9.x must have valid byte strides error. On jax 0.10 the propagation path now reads layout_constraint_p in get_out_layouts_via_propagation and actually pins the requested layout to the jit output — which is the wrong layout for the MoE kernel's weight reads and silently miscomputes.

Concretely:

  • gemma-4-26B-A4B-it GSM8K accuracy drops from 0.41 to 0.0
  • Individual prompts return garbage tokens

Fix

Remove three with_layout_constraint call sites in tpu_inference/layers/common/process_weights/moe_weights.py (unquantized weight processing + fused-moe weight reshape) and drop the now-unused import. The MoE kernel can use the natural C-order layout the auto-layout pass picks; no constraint is needed on jax 0.10. All other MoE quantization paths route through the same process_moe_weights, so they all benefit.

Verification (from #2665)

  • Before: gemma-4-26B-A4B-it returns garbage tokens on jax 0.10.0.
  • After: same model returns 4 (2+2), 56 (7x8), coherent text on a haiku prompt.

Perf comparison on the v0.21.0 nightly (builds 17833 baseline vs 18033 with the fix):

Model Output throughput P99 TPOT
Qwen3-Coder-480B-A35B-Instruct (FP8) −0.12% (noise) −0.16% (noise)
Qwen3-30B-A3B (the model #778 cited for the constraint's 10% speedup) +1.27% −4.36%

i.e. on jax 0.10, auto-layout picks an equally-good (or better) layout for the MoE GMM kernel; the manual constraint is no longer earning its keep.

Related

Test plan

  • CI runs the full PR pipeline on JAX 0.10.0
  • gemma-4-26B-A4B-it accuracy back to ≥ baseline (was 0.0, expect ~0.41)
  • No regression on Qwen3 MoE workloads (BCM verified on nightly 18033)

@github-actions
Copy link
Copy Markdown

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a Github issue, please include a link, e.g.,:
FIXES: #123456

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@QiliangCui2023 QiliangCui2023 added the ready ONLY add when PR is ready to merge/full CI is needed label May 20, 2026
@kyuyeunk
Copy link
Copy Markdown
Collaborator

@guowei-dev: fyi, after this pr has been merged, the reverted sparse core kernel pr can be merged again: #2634

Copy link
Copy Markdown
Collaborator

@lk-chen lk-chen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with one nit. Thanks for chasing this!

Comment thread tpu_inference/kernels/sparse_core/gather_reduce.py
This reverts commit 60cd862.

Signed-off-by: Qiliang Cui <derrhein@gmail.com>
…ax 0.10

The with_layout_constraint(weight, Layout(...)) calls in process_moe_weights
were a workaround for a jax-0.9.x "must have valid byte strides" error. On
jax 0.10 the propagation path now reads layout_constraint_p in
get_out_layouts_via_propagation and actually pins the requested layout to
the jit output — which is the WRONG layout for the MoE kernel's weight reads
and silently miscomputes: gemma-4-26B-A4B-it GSM8K accuracy drops from 0.41
to 0.0, and individual prompts return garbage.

The MoE kernel can use the natural C-order layout that the auto-layout pass
picks; no constraint is needed on jax 0.10. Removing the three call sites
(unquantized weight processing + fused-moe weight reshape) fixes gemma-4
accuracy without reverting jax. All other MoE quantization paths route
through the same process_moe_weights, so they all benefit from the fix.

Verified:
- Before: gemma-4-26B-A4B-it returns garbage tokens on jax 0.10.0.
- After:  same model returns "4" (2+2), "56" (7x8), coherent text on a
  haiku prompt.

Signed-off-by: Qiliang Cui <derrhein@gmail.com>
@QiliangCui2023 QiliangCui2023 force-pushed the cuiq-jax010-with-moe-fix branch from 0f619cb to a801dde Compare May 20, 2026 23:49
@QiliangCui QiliangCui enabled auto-merge (squash) May 21, 2026 01:44
@QiliangCui QiliangCui merged commit acd00f9 into main May 21, 2026
62 checks passed
ylangtsou pushed a commit to ylangtsou/tpu-inference that referenced this pull request May 21, 2026
Signed-off-by: Qiliang Cui <derrhein@gmail.com>
Co-authored-by: Qiliang Cui <derrhein@gmail.com>
jyj0w0 pushed a commit that referenced this pull request May 25, 2026
Signed-off-by: Qiliang Cui <derrhein@gmail.com>
Co-authored-by: Qiliang Cui <derrhein@gmail.com>
patrickji2014 pushed a commit that referenced this pull request May 27, 2026
Signed-off-by: Qiliang Cui <derrhein@gmail.com>
Co-authored-by: Qiliang Cui <derrhein@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants