File tree Expand file tree Collapse file tree
mlx/backend/rocm/quantized Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -94,11 +94,17 @@ __device__ __forceinline__ void load_weight_vec(
9494 out[0 ] = v.x ;
9595 out[1 ] = v.y ;
9696 } else if constexpr (PPT == 4 ) {
97- uint4 v = *reinterpret_cast <const uint4*>(ptr);
98- out[0 ] = v.x ;
99- out[1 ] = v.y ;
100- out[2 ] = v.z ;
101- out[3 ] = v.w ;
97+ // Two uint2 loads instead of one uint4. The single-uint4 load
98+ // (global_load_b128) miscomputes in the 8-bit affine QMV/gather paths
99+ // (root cause: HIP_vector_type<unsigned int,4> codegen on RDNA 3.5 with
100+ // hipcc 7.13 / LLVM 23). Two paired global_load_b64 ops yield the same
101+ // throughput on RDNA 3.5 without the miscompile.
102+ uint2 v0 = *reinterpret_cast <const uint2*>(ptr);
103+ uint2 v1 = *reinterpret_cast <const uint2*>(ptr + 2 );
104+ out[0 ] = v0.x ;
105+ out[1 ] = v0.y ;
106+ out[2 ] = v1.x ;
107+ out[3 ] = v1.y ;
102108 } else {
103109 #pragma unroll
104110 for (int p = 0 ; p < PPT; p++) {
You can’t perform that action at this time.
0 commit comments