ROCm: fix 8-bit affine QMV miscompile from uint4 weight load #7
Merged
The qmv_tiled_kernel and gather_qmv_tiled_kernel use load_weight_vec to
fetch packed weights in one transaction. For 4-bit (PPT=2) the function
issues a uint2 load (global_load_b64) and produces correct output. For
8-bit (PPT=4) it was issuing a single uint4 load (global_load_b128) and
that path miscompiles on RDNA 3.5 with hipcc 7.13 / LLVM 23: the dot
products come out wrong, even though the load is naturally 16-byte
aligned and the indices, scale/bias lookups, and reductions are
otherwise identical to the 4-bit path that works.
Replace the single uint4 load with two paired uint2 loads. Both forms
issue 128 bits of weight traffic per lane per K-step, so there is no
throughput regression on RDNA 3.5 (one CU still issues both b64s in a
single cycle), but the codegen path is the same one that the 4-bit
kernel uses and that is already validated.
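The change can be illustrated with a host-side C++ sketch. This is not the code from qdequant.hpp, which is not reproduced in this description. Vec2 and Vec4 are hypothetical stand-ins for HIP's built-in uint2 and uint4, the two function names are invented for illustration, and memcpy stands in for the vector load instruction. The sketch also assumes PPT counts 32-bit packed words per lane, which is what the uint2/uint4 pairing implies.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical host-side stand-ins for HIP's built-in uint2 / uint4 types.
struct Vec2 { std::uint32_t x, y; };
struct Vec4 { std::uint32_t x, y, z, w; };

// Old form for PPT == 4: one 128-bit load (global_load_b128 on the device).
inline void load_weight_vec_b128(const std::uint32_t* src,
                                 std::uint32_t (&dst)[4]) {
    assert(reinterpret_cast<std::uintptr_t>(src) % 16 == 0);
    Vec4 v;
    std::memcpy(&v, src, sizeof v);  // stands in for a single uint4 load
    dst[0] = v.x; dst[1] = v.y; dst[2] = v.z; dst[3] = v.w;
}

// New form for PPT == 4: two paired 64-bit loads (two global_load_b64),
// the same load width the 4-bit (PPT == 2) path already uses.
inline void load_weight_vec_2xb64(const std::uint32_t* src,
                                  std::uint32_t (&dst)[4]) {
    assert(reinterpret_cast<std::uintptr_t>(src) % 8 == 0);
    Vec2 lo, hi;
    std::memcpy(&lo, src, sizeof lo);      // words 0..1
    std::memcpy(&hi, src + 2, sizeof hi);  // words 2..3
    dst[0] = lo.x; dst[1] = lo.y; dst[2] = hi.x; dst[3] = hi.y;
}
```

Both functions fill dst with the same four words. On the host they are trivially equivalent; the point of the fix is only that the device compiler takes the 64-bit path for both bit widths.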
Repro before this change (gfx1151, hipcc 7.13):
Qwen3-Coder-Next-4bit ("model_type": "qwen3_next" + 8-bit overrides
for every mlp.gate / shared_expert_gate, default 4-bit elsewhere)
decoded gibberish from the first generated token because the MoE
router gate output was wrong. Setting MLX_ROCM_QMV_NO_TILED=1 (which
routes through the qmv_warp_shared scalar path) restored correct
output; setting MLX_ROCM_QMM_DEQUANT_GEMM=0 did not (which ruled out
the dequant+rocBLAS path). Bisecting via per-bitwidth dispatch
isolated the bug to qmv_tiled_kernel's 8-bit instantiation; running
the kernel with scalar w_row[w_offset + p] loads instead of
load_weight_vec also restored correct output, which pinned the
miscompile to the uint4 path inside load_weight_vec.
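When bisecting a miscompile like this, a scalar host reference for one output row gives something to compare kernel results against. The following is a minimal sketch, not part of this PR. It assumes four 8-bit values per 32-bit word with the lowest byte first, and the affine form w = scale * q + bias with one scale and bias per group; the function name and signature are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Reference dot product of one 8-bit affine-quantized weight row with x.
// Assumptions (not taken from the PR): low byte first within each word,
// and dequantization w = scale * q + bias per group of group_size values.
inline float qmv_row_ref_8bit(const std::uint32_t* w_row,
                              const float* scales,
                              const float* biases,
                              const float* x,
                              std::size_t K,
                              std::size_t group_size) {
    assert(group_size > 0 && K % 4 == 0 && K % group_size == 0);
    float acc = 0.0f;
    for (std::size_t k = 0; k < K; ++k) {
        const std::uint32_t word = w_row[k / 4];
        const std::uint32_t q = (word >> (8 * (k % 4))) & 0xFFu;  // unpack one byte
        const std::size_t g = k / group_size;                      // group index
        acc += (scales[g] * static_cast<float>(q) + biases[g]) * x[k];
    }
    return acc;
}
```

Comparing this value against the tiled kernel's output for a single router-gate row would show whether the 8-bit instantiation diverges, independent of the rest of the model.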
Verified after the fix on gfx1151 (Strix Halo, RDNA 3.5, ROCm 7.13):
Qwen3-0.6B-4bit -> "2 + 2 = 4."
Qwen3-1.7B-4bit -> "2 + 2 = 4."
Qwen3-4B-4bit -> "The sum of 2 and 2 is 4..."
Qwen3-8B-4bit -> "2 + 2 = 4."
Qwen3.5-35B-A3B-4bit -> "2 plus two equals **4**."
Qwen3-Coder-Next-4bit -> "<tool_call> 2 + 2 = 4" (was gibberish)
Affects QuantizedMatmul / GatherQMM for any 8-bit affine quantization
on the tiled QMV fast path, which is exercised by every MoE model that
quantizes its router gate at 8 bits (and any future 8-bit-only model
that lands on this path).
Summary
load_weight_vec<BITS> in mlx/backend/rocm/quantized/qdequant.hpp was issuing a single uint4 load (→ global_load_b128) for the 8-bit affine path (PPT=4). That codegen path miscompiles on RDNA 3.5 with hipcc 7.13 / LLVM 23, even though the load is naturally 16-byte aligned and the indices, scale/bias lookups, and reductions match the working 4-bit path. The 4-bit path (PPT=2, uint2 / global_load_b64) is unaffected.
Replace the single uint4 load with two paired uint2 loads. Same 128 bits of weight traffic, same throughput on RDNA 3.5 (one CU still issues both b64s in a single cycle), but the codegen path is the same one that the 4-bit kernel already validates.
Bisection
On gfx1151 (Strix Halo) / ROCm 7.13:
- The quantization config has default 4-bit affine but 8-bit overrides for every mlp.gate / shared_expert_gate (96 layers × 2 overrides), i.e. the MoE router gate. Because the router output was wrong, every downstream expert dispatch was wrong.
- MLX_ROCM_QMM_DEQUANT_GEMM=0 did NOT fix it → ruled out the dequant+rocBLAS GEMM path.
- MLX_ROCM_QMV_NO_TILED=1 (route 8-bit QMV through qmv_warp_shared_kernel instead of qmv_tiled_kernel) DID fix it → pinned to qmv_tiled_kernel.
- Replacing load_weight_vec<BITS>(w_row + w_offset, w_local) in the kernel's fast-path branch with the scalar bounds-checked loop w_local[p] = w_row[w_offset + p] also fixed it → pinned to load_weight_vec.
- In load_weight_vec, only the PPT == 4 branch (uint4) was different from the working PPT == 2 branch (uint2). Replacing the single uint4 load with two uint2 loads → fixed.
Verified on gfx1151 / ROCm 7.13 (Strix Halo)
Impact
Affects QuantizedMatmul / GatherQMM for any 8-bit affine quantization on the tiled QMV fast path. Exercised by every MoE model that quantizes its router gate at 8 bits (Qwen3-Coder-Next, and likely other recent Qwen3-Next-family / hybrid-attention MoE checkpoints), plus any future 8-bit-only model that lands on this path.
Test plan