ROCm: fix 8-bit affine QMV miscompile from uint4 weight load #7

Merged: Geramy merged 1 commit into rocm-support from geramy/fix-rocm-qmv-8bit-uint4-miscompile on May 20, 2026
Geramy (Collaborator) commented May 20, 2026

Summary

load_weight_vec<BITS> in mlx/backend/rocm/quantized/qdequant.hpp was issuing a single uint4 load (→ global_load_b128) for the 8-bit affine path (PPT=4). That codegen path miscompiles on RDNA 3.5 with hipcc 7.13 / LLVM 23, even though the load is naturally 16-byte aligned and the indices, scale/bias lookups, and reductions match the working 4-bit path. The 4-bit path (PPT=2, uint2 / global_load_b64) is unaffected.

Replace the single uint4 load with two paired uint2 loads. Same 128 bits of weight traffic, same throughput on RDNA 3.5 (one CU still issues both b64s in a single cycle), but the codegen path is the same one that the 4-bit kernel already validates.
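
The substitution described above can be sketched on the host in plain C++. This is a minimal illustration, not the code from qdequant.hpp: `U2`, `ppt`, `load_u2`, and `load_weight_vec_sketch` are hypothetical names, `U2` stands in for HIP's `uint2`, and the argument types (a `uint32_t` source pointer and a `uint32_t` destination array of PPT words) are assumed from the PPT values and load widths quoted in the summary.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical host-side stand-in for HIP's 64-bit `uint2` vector type.
struct alignas(8) U2 {
  std::uint32_t x, y;
};

// PPT per bit width as stated in the PR: 4-bit -> 2, 8-bit -> 4.
template <int BITS>
constexpr int ppt() {
  static_assert(BITS == 4 || BITS == 8, "sketch covers only 4- and 8-bit");
  return BITS == 4 ? 2 : 4;
}

// One 64-bit load. On the device this is the access that the PR says
// lowers to global_load_b64; here it is an ordinary aligned copy.
inline U2 load_u2(const std::uint32_t* p) {
  assert(reinterpret_cast<std::uintptr_t>(p) % alignof(U2) == 0);
  U2 v;
  std::memcpy(&v, p, sizeof(U2));
  return v;
}

// Fetch PPT packed words using only 64-bit loads:
// one load for 4-bit (PPT=2), two paired loads for 8-bit (PPT=4).
// No 128-bit (uint4) load is ever issued.
template <int BITS>
void load_weight_vec_sketch(const std::uint32_t* src, std::uint32_t* dst) {
  constexpr int PPT = ppt<BITS>();
  for (int i = 0; i < PPT; i += 2) {
    const U2 v = load_u2(src + i);
    dst[i] = v.x;
    dst[i + 1] = v.y;
  }
}
```

Writing the 8-bit case as a loop over 64-bit pairs keeps both bit widths on the same load form, which is the property the fix relies on.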

Bisection

On gfx1151 (Strix Halo) / ROCm 7.13:

  • Qwen3-Coder-Next-4bit decoded gibberish from the first generated token. The model's quantization config has default 4-bit affine but 8-bit overrides for every mlp.gate / shared_expert_gate (96 layers × 2 overrides) — the MoE router gate. Because the router output was wrong, every downstream expert dispatch was wrong.
  • MLX_ROCM_QMM_DEQUANT_GEMM=0 did NOT fix it → ruled out the dequant+rocBLAS GEMM path.
  • MLX_ROCM_QMV_NO_TILED=1 (route 8-bit QMV through qmv_warp_shared_kernel instead of qmv_tiled_kernel) DID fix it → pinned to qmv_tiled_kernel.
  • Replacing load_weight_vec<BITS>(w_row + w_offset, w_local) in the kernel's fast-path branch with the scalar bounds-checked loop w_local[p] = w_row[w_offset + p] also fixed it → pinned to load_weight_vec.
  • Inside load_weight_vec only the PPT == 4 branch (uint4) was different from the working PPT == 2 branch (uint2). Replacing the single uint4 load with two uint2 loads → fixed.
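
The second bisection step depends on an environment toggle that reroutes 8-bit QMV away from the tiled kernel. A rough sketch of such a switch follows. The variable name comes from this PR, but the parsing rule (enabled only when the value is exactly "1") and the names `QmvKernel`, `env_flag_is_one`, and `select_qmv_kernel` are assumptions made for illustration, not the backend's actual dispatch code.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// Hypothetical selector; the two values mirror the kernels named in the PR
// (qmv_tiled_kernel and qmv_warp_shared_kernel).
enum class QmvKernel { Tiled, WarpShared };

// Assumption: a flag counts as set only when the variable exists and its
// value is exactly "1". The real backend may parse it differently.
inline bool env_flag_is_one(const char* name) {
  assert(name != nullptr);
  const char* value = std::getenv(name);
  return value != nullptr && std::strcmp(value, "1") == 0;
}

// With MLX_ROCM_QMV_NO_TILED=1 the sketch picks the warp-shared path,
// otherwise the tiled fast path.
inline QmvKernel select_qmv_kernel() {
  return env_flag_is_one("MLX_ROCM_QMV_NO_TILED") ? QmvKernel::WarpShared
                                                  : QmvKernel::Tiled;
}
```

Flipping one such switch at a time and re-running the same prompt is what lets each step above rule a path in or out.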

Verified on gfx1151 / ROCm 7.13 (Strix Halo)

| Model | Output | Status |
| --- | --- | --- |
| Qwen3-0.6B-4bit | "2 + 2 = 4." | |
| Qwen3-1.7B-4bit | "2 + 2 = 4." | |
| Qwen3-4B-4bit | "The sum of 2 and 2 is 4..." | |
| Qwen3-8B-4bit | "2 + 2 = 4." | |
| Qwen3.5-35B-A3B-4bit | "2 plus two equals 4." | |
| Qwen3-Coder-Next-4bit | "<tool_call> 2 + 2 = 4" | ✅ (was gibberish) |

Impact

Affects QuantizedMatmul / GatherQMM for any 8-bit affine quantization on the tiled QMV fast path. Exercised by every MoE model that quantizes its router gate at 8 bits (Qwen3-Coder-Next, and likely other recent Qwen3-Next-family / hybrid-attention MoE checkpoints), plus any future 8-bit-only model that lands on this path.

Test plan

  • Qwen3 dense 0.6B / 1.7B / 4B / 8B (4-bit) — verified no regression
  • Qwen3.5-35B-A3B-4bit (4-bit MoE) — verified no regression
  • Qwen3-Coder-Next-4bit (4-bit MoE with 8-bit router gates) — verified fix
  • CDNA / wave64 path — not retested by me; the code change is purely a load-instruction substitution so behaviour should be identical there, but worth a sanity run before merge
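
For a sanity run on an untested target, one option is to compare a kernel's output for a single quantization group against a scalar host reference. The sketch below is such a reference under stated assumptions: four 8-bit values per `uint32_t` word with the lowest byte first, and affine dequantization `w = scale * q + bias` with one scale and bias per group. The function name `qdot8_affine_ref` is hypothetical and not part of MLX.

```cpp
#include <cassert>
#include <cstdint>

// Scalar reference dot product for one 8-bit affine-quantized group.
// Assumptions (not taken from the MLX source): values are packed four per
// 32-bit word, lowest byte first, and dequantize as scale * q + bias.
inline float qdot8_affine_ref(const std::uint32_t* packed,
                              const float* x,
                              int n,
                              float scale,
                              float bias) {
  assert(n >= 0 && n % 4 == 0);  // whole packed words only
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    const std::uint32_t word = packed[i / 4];
    const std::uint32_t q = (word >> (8 * (i % 4))) & 0xFFu;
    acc += x[i] * (scale * static_cast<float>(q) + bias);
  }
  return acc;
}
```

A device result would be compared against this value with a small tolerance, since the kernel accumulates in a different order.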

The qmv_tiled_kernel and gather_qmv_tiled_kernel use load_weight_vec to
fetch packed weights in one transaction. For 4-bit (PPT=2) the function
issues a uint2 load (global_load_b64) and produces correct output. For
8-bit (PPT=4) it was issuing a single uint4 load (global_load_b128) and
that path miscompiles on RDNA 3.5 with hipcc 7.13 / LLVM 23: the dot
products come out wrong, even though the load is naturally 16-byte
aligned and the indices, scale/bias lookups, and reductions are
otherwise identical to the 4-bit path that works.

Replace the single uint4 load with two paired uint2 loads. Both forms
issue 128 bits of weight traffic per lane per K-step, so there is no
throughput regression on RDNA 3.5 (one CU still issues both b64s in a
single cycle), but the codegen path is now the same one that the 4-bit
kernel uses and that is already validated.

Repro before this change (gfx1151, hipcc 7.13):
  Qwen3-Coder-Next-4bit ("model_type": "qwen3_next" + 8-bit overrides
  for every mlp.gate / shared_expert_gate, default 4-bit elsewhere)
  decoded gibberish from the first generated token because the MoE
  router gate output was wrong. Setting MLX_ROCM_QMV_NO_TILED=1 (which
  routes through the qmv_warp_shared scalar path) restored correct
  output; setting MLX_ROCM_QMM_DEQUANT_GEMM=0 did not (which ruled out
  the dequant+rocBLAS path). Bisecting via per-bitwidth dispatch
  isolated the bug to qmv_tiled_kernel's 8-bit instantiation; running
  the kernel with scalar w_row[w_offset + p] loads instead of
  load_weight_vec also restored correct output, which pinned the
  miscompile to the uint4 path inside load_weight_vec.

Verified after the fix on gfx1151 (Strix Halo, RDNA 3.5, ROCm 7.13):
  Qwen3-0.6B-4bit     -> "2 + 2 = 4."
  Qwen3-1.7B-4bit     -> "2 + 2 = 4."
  Qwen3-4B-4bit       -> "The sum of 2 and 2 is 4..."
  Qwen3-8B-4bit       -> "2 + 2 = 4."
  Qwen3.5-35B-A3B-4bit -> "2 plus two equals **4**."
  Qwen3-Coder-Next-4bit -> "<tool_call> 2 + 2 = 4"  (was gibberish)

Affects QuantizedMatmul / GatherQMM for any 8-bit affine quantization
on the tiled QMV fast path, which is exercised by every MoE model that
quantizes its router gate at 8 bits (and any future 8-bit-only model
that lands on this path).

Geramy merged commit 9e768e4 into rocm-support on May 20, 2026
1 check failed