Skip to content

Commit 9e768e4

Browse files
authored
Merge pull request #7 from NripeshN/geramy/fix-rocm-qmv-8bit-uint4-miscompile
ROCm: fix 8-bit affine QMV miscompile from uint4 weight load
2 parents 767b0aa + e15fcef commit 9e768e4

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

mlx/backend/rocm/quantized/qdequant.hpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,17 @@ __device__ __forceinline__ void load_weight_vec(
9494
out[0] = v.x;
9595
out[1] = v.y;
9696
} else if constexpr (PPT == 4) {
97-
uint4 v = *reinterpret_cast<const uint4*>(ptr);
98-
out[0] = v.x;
99-
out[1] = v.y;
100-
out[2] = v.z;
101-
out[3] = v.w;
97+
// Two uint2 loads instead of one uint4. The single-uint4 load
98+
// (global_load_b128) miscomputes in the 8-bit affine QMV/gather paths
99+
// (root cause: HIP_vector_type<unsigned int,4> codegen on RDNA 3.5 with
100+
// hipcc 7.13 / LLVM 23). Two paired global_load_b64 ops yield the same
101+
// throughput on RDNA 3.5 without the miscompile.
102+
uint2 v0 = *reinterpret_cast<const uint2*>(ptr);
103+
uint2 v1 = *reinterpret_cast<const uint2*>(ptr + 2);
104+
out[0] = v0.x;
105+
out[1] = v0.y;
106+
out[2] = v1.x;
107+
out[3] = v1.y;
102108
} else {
103109
#pragma unroll
104110
for (int p = 0; p < PPT; p++) {

0 commit comments

Comments
 (0)