Commit e15fcef

ROCm: fix 8-bit affine QMV miscompile from uint4 weight load

The qmv_tiled_kernel and gather_qmv_tiled_kernel use load_weight_vec to fetch packed weights in one transaction. For 4-bit (PPT=2) the function issues a uint2 load (global_load_b64) and produces correct output. For 8-bit (PPT=4) it was issuing a single uint4 load (global_load_b128) and that path miscompiles on RDNA 3.5 with hipcc 7.13 / LLVM 23: the dot products come out wrong, even though the load is naturally 16-byte aligned and the indices, scale/bias lookups, and reductions are otherwise identical to the 4-bit path that works.

Replace the single uint4 load with two paired uint2 loads. Both forms issue 128 bits of weight traffic per lane per K-step, so there is no throughput regression on RDNA 3.5 (one CU still issues both b64s in a single cycle), but the codegen path is the same one that the 4-bit kernel uses and already validated.

Repro before this change (gfx1151, hipcc 7.13): Qwen3-Coder-Next-4bit ("model_type": "qwen3_next" + 8-bit overrides for every mlp.gate / shared_expert_gate, default 4-bit elsewhere) decoded gibberish from the first generated token because the MoE router gate output was wrong. Setting MLX_ROCM_QMV_NO_TILED=1 (which routes through the qmv_warp_shared scalar path) restored correct output; setting MLX_ROCM_QMM_DEQUANT_GEMM=0 did not (which ruled out the dequant+rocBLAS path). Bisecting via per-bitwidth dispatch isolated the bug to qmv_tiled_kernel's 8-bit instantiation; running the kernel with scalar w_row[w_offset + p] loads instead of load_weight_vec also restored correct output, which pinned the miscompile to the uint4 path inside load_weight_vec.

Verified after the fix on gfx1151 (Strix Halo, RDNA 3.5, ROCm 7.13):

  Qwen3-0.6B-4bit        -> "2 + 2 = 4."
  Qwen3-1.7B-4bit        -> "2 + 2 = 4."
  Qwen3-4B-4bit          -> "The sum of 2 and 2 is 4..."
  Qwen3-8B-4bit          -> "2 + 2 = 4."
  Qwen3.5-35B-A3B-4bit   -> "2 plus two equals **4**."
  Qwen3-Coder-Next-4bit  -> "<tool_call> 2 + 2 = 4" (was gibberish)

Affects QuantizedMatmul / GatherQMM for any 8-bit affine quantization on the tiled QMV fast path, which is exercised by every MoE model that quantizes its router gate at 8 bits (and any future 8-bit-only model that lands on this path).
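
When chasing a wrong dot product like the one described in this message, a small host-side oracle can help confirm what the 8-bit affine path should produce. The sketch below is not the MLX kernel: `ref_qdot_8bit` is a hypothetical name, it assumes four 8-bit values per 32-bit word with the lowest byte first, and it uses a single scale and bias for the whole span (as if the span were one quantization group), dequantizing each value as `scale * q + bias`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side reference, not part of the commit.
// Assumed layout: each uint32_t packs four 8-bit quantized values,
// lowest byte first. One scale/bias pair covers the whole span.
inline float ref_qdot_8bit(const uint32_t* packed,
                           const float* x,
                           int n_values,
                           float scale,
                           float bias) {
  assert(n_values % 4 == 0);  // whole 32-bit words only
  float acc = 0.0f;
  for (int i = 0; i < n_values; ++i) {
    const uint32_t word = packed[i / 4];
    const uint32_t q = (word >> (8 * (i % 4))) & 0xFFu;
    // Affine dequantization, then accumulate against the activation.
    acc += (scale * static_cast<float>(q) + bias) * x[i];
  }
  return acc;
}
```

Comparing a value like this against the kernel's output for a single row is one way to tell a bad load from a bad scale/bias lookup, under the layout assumptions stated above.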
1 parent 526dbbd commit e15fcef

1 file changed

Lines changed: 11 additions & 5 deletions

mlx/backend/rocm/quantized/qdequant.hpp

@@ -94,11 +94,17 @@ __device__ __forceinline__ void load_weight_vec(
     out[0] = v.x;
     out[1] = v.y;
   } else if constexpr (PPT == 4) {
-    uint4 v = *reinterpret_cast<const uint4*>(ptr);
-    out[0] = v.x;
-    out[1] = v.y;
-    out[2] = v.z;
-    out[3] = v.w;
+    // Two uint2 loads instead of one uint4. The single-uint4 load
+    // (global_load_b128) miscomputes in the 8-bit affine QMV/gather paths
+    // (root cause: HIP_vector_type<unsigned int,4> codegen on RDNA 3.5 with
+    // hipcc 7.13 / LLVM 23). Two paired global_load_b64 ops yield the same
+    // throughput on RDNA 3.5 without the miscompile.
+    uint2 v0 = *reinterpret_cast<const uint2*>(ptr);
+    uint2 v1 = *reinterpret_cast<const uint2*>(ptr + 2);
+    out[0] = v0.x;
+    out[1] = v0.y;
+    out[2] = v1.x;
+    out[3] = v1.y;
   } else {
 #pragma unroll
     for (int p = 0; p < PPT; p++) {
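
The hunk above relies on two 8-byte loads at `ptr` and `ptr + 2` covering the same four words as one 16-byte load. A host-side sketch of that equivalence follows, assuming `ptr` points to 32-bit words (the diff does not show its declaration). `U2`, `load4_as_two_u2`, `load4_scalar`, and `loads_agree` are hypothetical names for illustration, and `std::memcpy` stands in for the device-side `reinterpret_cast` loads. It cannot reproduce the GPU miscompile; it only checks the indexing.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical stand-in for HIP's uint2 vector type.
struct U2 {
  uint32_t x, y;
};

// Host analogue of the fixed PPT == 4 branch: two 8-byte loads,
// the second starting two 32-bit words (8 bytes) after the first.
inline void load4_as_two_u2(const uint32_t* ptr, uint32_t out[4]) {
  U2 v0, v1;
  std::memcpy(&v0, ptr, sizeof(U2));
  std::memcpy(&v1, ptr + 2, sizeof(U2));
  out[0] = v0.x;
  out[1] = v0.y;
  out[2] = v1.x;
  out[3] = v1.y;
}

// Scalar reference, analogous to the w_row[w_offset + p] loads
// used while bisecting.
inline void load4_scalar(const uint32_t* ptr, uint32_t out[4]) {
  for (int p = 0; p < 4; ++p) {
    out[p] = ptr[p];
  }
}

// True when both forms return the same four words.
inline bool loads_agree(const uint32_t* ptr) {
  uint32_t a[4], b[4];
  load4_as_two_u2(ptr, a);
  load4_scalar(ptr, b);
  return std::memcmp(a, b, sizeof(a)) == 0;
}
```

On a device, the analogous check would compare the kernel's output with `load_weight_vec` against the scalar-load variant, which is the comparison the commit message describes.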
