Commit e15fcef
committed
ROCm: fix 8-bit affine QMV miscompile from uint4 weight load
The qmv_tiled_kernel and gather_qmv_tiled_kernel use load_weight_vec to
fetch packed weights in one transaction. For 4-bit (PPT=2) the function
issues a uint2 load (global_load_b64) and produces correct output. For
8-bit (PPT=4) it was issuing a single uint4 load (global_load_b128) and
that path miscompiles on RDNA 3.5 with hipcc 7.13 / LLVM 23: the dot
products come out wrong, even though the load is naturally 16-byte
aligned and the indices, scale/bias lookups, and reductions are
otherwise identical to the 4-bit path that works.
Replace the single uint4 load with two paired uint2 loads. Both forms
issue 128 bits of weight traffic per lane per K-step, so there is no
throughput regression on RDNA 3.5 (one CU still issues both b64s in a
single cycle), but the codegen path is the same one that the 4-bit
kernel uses and already validated.
Repro before this change (gfx1151, hipcc 7.13):
Qwen3-Coder-Next-4bit ("model_type": "qwen3_next" + 8-bit overrides
for every mlp.gate / shared_expert_gate, default 4-bit elsewhere)
decoded gibberish from the first generated token because the MoE
router gate output was wrong. Setting MLX_ROCM_QMV_NO_TILED=1 (which
routes through the qmv_warp_shared scalar path) restored correct
output; setting MLX_ROCM_QMM_DEQUANT_GEMM=0 did not (which ruled out
the dequant+rocBLAS path). Bisecting via per-bitwidth dispatch
isolated the bug to qmv_tiled_kernel's 8-bit instantiation; running
the kernel with scalar w_row[w_offset + p] loads instead of
load_weight_vec also restored correct output, which pinned the
miscompile to the uint4 path inside load_weight_vec.
Verified after the fix on gfx1151 (Strix Halo, RDNA 3.5, ROCm 7.13):
Qwen3-0.6B-4bit -> "2 + 2 = 4."
Qwen3-1.7B-4bit -> "2 + 2 = 4."
Qwen3-4B-4bit -> "The sum of 2 and 2 is 4..."
Qwen3-8B-4bit -> "2 + 2 = 4."
Qwen3.5-35B-A3B-4bit -> "2 plus two equals **4**."
Qwen3-Coder-Next-4bit -> "<tool_call> 2 + 2 = 4" (was gibberish)
Affects QuantizedMatmul / GatherQMM for any 8-bit affine quantization
on the tiled QMV fast path, which is exercised by every MoE model that
quantizes its router gate at 8 bits (and any future 8-bit-only model
that lands on this path).1 parent 526dbbd commit e15fcef
1 file changed
Lines changed: 11 additions & 5 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
94 | 94 | | |
95 | 95 | | |
96 | 96 | | |
97 | | - | |
98 | | - | |
99 | | - | |
100 | | - | |
101 | | - | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
102 | 108 | | |
103 | 109 | | |
104 | 110 | | |
| |||
0 commit comments