Fix KV head shape mismatch when TP size exceeds num_kv_heads #3426
copybara-service[bot] merged 1 commit into main
NicoGrande
left a comment
Could you test on a VM with 4 JAX devices (maybe v5p-8)? I tried on my v6e-4 and saw nonsense outputs, so this may currently be breaking the code path where no changes are necessary.
xuefgu
left a comment
Which Orbax version is this PR tested against?
At what scale is this PR tested? Any practical runs on GKE using Pathways?
A run for Qwen3-30B-A3B on 1 slice of v5p-128: https://cloudlogging.app.goo.gl/Rna8hP21D5XKAHoF9
More runs are in this bug: b/498435735
xuefgu
left a comment
For posterity:
(1) Fast follow - logical axis rules
(2) Currently it works for v5p but not for v7x yet, although we are not sure whether that's because of our stack. See b/496608632.
Description
This PR was done in collaboration with @NicoGrande
Problem: When serving models like Qwen3-30B-A3B (4 KV heads) with TP=8, adapter.py pads base_num_kv_heads to 8 to match the TP size. As a result, Orbax rejected the checkpoint restore because the stored shape (seq, 4, 128) did not match the model's padded shape (seq, 8, 128).
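The mismatch can be reproduced with shape arithmetic alone. The sketch below assumes the padding rule is "raise the KV head count to the TP size whenever TP is larger"; `padded_kv_heads` is a hypothetical helper for illustration, not a function from adapter.py or Orbax, and `seq = 16` is an arbitrary placeholder.

```python
def padded_kv_heads(num_kv_heads: int, tp_size: int) -> int:
    """Assumed padding rule: each TP shard needs at least one KV head."""
    return max(num_kv_heads, tp_size)


seq, head_dim = 16, 128  # seq is an arbitrary placeholder length

# Shape stored in the checkpoint: 4 KV heads.
stored_shape = (seq, 4, head_dim)

# Shape the model expects after padding for TP=8.
model_shape = (seq, padded_kv_heads(4, 8), head_dim)

# A strict shape check on restore fails for this pair.
mismatch = stored_shape != model_shape

# With TP=4 nothing is padded, so the shapes already agree.
unpadded_shape = (seq, padded_kv_heads(4, 4), head_dim)
```

The TP=4 case is the path where no change should be needed, which is the case the first review comment asks to have tested.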
Fix:
[h0,h1,...,h0,h1,...].
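The layout `[h0,h1,...,h0,h1,...]` reads as the stored heads repeated as a whole block. As a rough illustration only, the NumPy sketch below expands a `(seq, 4, 128)` array to `(seq, 8, 128)` in that order. `tile_kv_heads` is a hypothetical name, and where the PR actually performs this step (during restore or afterwards) is not stated in the text above.

```python
import numpy as np


def tile_kv_heads(kv: np.ndarray, target_heads: int) -> np.ndarray:
    """Repeat the head block along axis 1: [h0,h1,...] -> [h0,h1,...,h0,h1,...]."""
    num_heads = kv.shape[1]
    if target_heads == num_heads:
        return kv  # nothing to pad; leave this path untouched
    if target_heads % num_heads != 0:
        raise ValueError(
            f"target_heads={target_heads} is not a multiple of {num_heads}"
        )
    reps = target_heads // num_heads
    # np.tile repeats the whole axis, giving block order rather than
    # interleaved order ([h0,h0,h1,h1,...]).
    return np.tile(kv, (1, reps, 1))


seq, head_dim = 2, 128  # small placeholder sequence length
stored = np.arange(seq * 4 * head_dim, dtype=np.float32).reshape(seq, 4, head_dim)
padded = tile_kv_heads(stored, 8)
```

Here heads 4 to 7 of `padded` are copies of heads 0 to 3, and calling the helper with `target_heads=4` returns the input unchanged.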
Tests
CI Tests
vllm_decode for Qwen3-30B-A3B with 4 kv heads on a v6e-8