Skip to content

Commit 516b5a1

Browse files
authored
Merge pull request #6 from soloish90/fix/rocm-qmv-tiled-8bit
Fix: broken/missing 8-bit inference in tiled QMV path
2 parents d999ca6 + 39fac95 commit 516b5a1

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

mlx/backend/rocm/quantized/qmm.hip

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2959,12 +2959,20 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
29592959
if (group_size_ == 32) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 4, 32); }
29602960
else if (group_size_ == 64) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 4, 64); }
29612961
else if (group_size_ == 128) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 4, 128); }
2962+
} else if (bits_ == 8) {
2963+
if (group_size_ == 32) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 8, 32); }
2964+
else if (group_size_ == 64) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 8, 64); }
2965+
else if (group_size_ == 128) { LAUNCH_TILED(hip_bfloat16, hip_bfloat16, 8, 128); }
29622966
}
29632967
} else if (x.dtype() == float16) {
29642968
if (bits_ == 4) {
29652969
if (group_size_ == 32) { LAUNCH_TILED(__half, __half, 4, 32); }
29662970
else if (group_size_ == 64) { LAUNCH_TILED(__half, __half, 4, 64); }
29672971
else if (group_size_ == 128) { LAUNCH_TILED(__half, __half, 4, 128); }
2972+
} else if (bits_ == 8) {
2973+
if (group_size_ == 32) { LAUNCH_TILED(__half, __half, 8, 32); }
2974+
else if (group_size_ == 64) { LAUNCH_TILED(__half, __half, 8, 64); }
2975+
else if (group_size_ == 128) { LAUNCH_TILED(__half, __half, 8, 128); }
29682976
}
29692977
}
29702978
#undef LAUNCH_TILED

0 commit comments

Comments
 (0)