Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlx/backend/metal/kernels/fp_quantized.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ METAL_FUNC void fp_qmv_impl(
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
const device auto* sl = scales + row * in_vec_size_g;

uint8_t s = sl[0];
U s = dequantize_scale<U, group_size>(sl[0]);
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s);
}

Expand Down
41 changes: 41 additions & 0 deletions python/tests/test_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,47 @@ def test_fp_qmv(self):
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)

def test_fp_qmv_small_non_multiples(self):
    # Regression coverage for the out_vec_size < 8 path of fp_qmv_impl
    # (see fp_quantized.h). The full-block loop on that path used the fp8
    # scale byte as-is instead of decoding it with dequantize_scale.
    #
    # Two conditions must hold together to reach the faulty code: an
    # fp-quantized matvec whose output dim is < 8, and K > block_size so
    # that at least one full block runs. test_fp_qmv only uses output
    # dims >= 8, and test_qmv_small_non_multiples uses K = 32 (below
    # block_size, so only the correct remainder loop executes), which is
    # why neither caught it.
    #
    # Here K = 512 guarantees full blocks, and N is drawn from 1..7 so the
    # < 8 branch is taken.
    K = 512
    norm = K**0.5
    root_key = mx.random.key(0)
    x_key, w_key = mx.random.split(root_key)

    # (mode, group_size, bits) for each fp quantization scheme under test.
    fp_configs = (
        ("mxfp4", 32, 4),
        ("mxfp8", 32, 8),
        ("nvfp4", 16, 4),
    )
    row_counts = (1, 2)
    out_dims = (1, 2, 3, 5, 7)

    for mode, group_size, bits in fp_configs:
        # The same quantization settings feed quantize, dequantize and
        # quantized_matmul, so build them once per scheme.
        quant_args = {"group_size": group_size, "bits": bits, "mode": mode}
        for M in row_counts:
            for N in out_dims:
                with self.subTest(M=M, K=K, N=N, mode=mode):
                    x = mx.random.normal(shape=(M, K), key=x_key) / norm
                    w = mx.random.normal(shape=(N, K), key=w_key) / norm

                    w_q, scales = mx.quantize(w, **quant_args)
                    w_hat = mx.dequantize(w_q, scales, **quant_args)

                    # Kernel under test vs. a reference matmul against
                    # the dequantized weights.
                    y_q = mx.quantized_matmul(
                        x, w_q, scales, transpose=True, **quant_args
                    )
                    w_hat_t = mx.swapaxes(w_hat, -1, -2)
                    y_hat = x @ w_hat_t

                    self.assertEqual(y_q.shape, y_hat.shape)
                    max_abs_err = (y_q - y_hat).abs().max()
                    self.assertLess(max_abs_err, 1e-3)

def test_qmv_wide(self):
# M in [2, vector_limit) routes to qmv_wide -- except K in {64, 128}
# with power-of-2 bits, which stays on qmv_quad. Check both paths
Expand Down