Skip to content

Commit ce93101

Browse files
jax-0n-gitclaude
andcommitted
Fix fp_qmv_impl small-output-dim branch using raw fp8 scale byte
The out_vec_size < 8 branch's full-block loop loaded the fp8 scale as a raw byte and passed it to qdot without dequantize_scale, so fp-quantized (mxfp4/mxfp8/nvfp4) matvec with output dim < 8 multiplied by the raw e8m0/e4m3 byte instead of the decoded scale (gross error, ~1e2-1e4 relative). Every other fp scale-load site decodes it, including the remainder loop of this same branch. Add test_fp_qmv_small_non_multiples covering the fp modes at output dim < 8 with K large enough to run the full-block loop (the existing fp tests use output dim >= 8 or K below block_size). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent b410f6c commit ce93101

2 files changed

Lines changed: 42 additions & 1 deletion

File tree

mlx/backend/metal/kernels/fp_quantized.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ METAL_FUNC void fp_qmv_impl(
440440
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
441441
const device auto* sl = scales + row * in_vec_size_g;
442442

443-
uint8_t s = sl[0];
443+
U s = dequantize_scale<U, group_size>(sl[0]);
444444
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s);
445445
}
446446

python/tests/test_quantized.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,47 @@ def test_fp_qmv(self):
482482
self.assertEqual(y_q.shape, y_hat.shape)
483483
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
484484

485+
def test_fp_qmv_small_non_multiples(self):
486+
# Regression test for the fp_qmv_impl out_vec_size < 8 branch, whose
487+
# full-block loop loaded the fp8 scale as a raw byte instead of decoding
488+
# it through dequantize_scale (see fp_quantized.h). The bug only bites
489+
# fp-quantized matvec when the output dim is < 8 AND K is large enough to
490+
# run at least one full block (K > block_size), so it slipped past both
491+
# test_fp_qmv (output dim >= 8) and test_qmv_small_non_multiples (K = 32,
492+
# below block_size, exercising only the correct remainder loop).
493+
# K = 512 forces full blocks; N in {1..7} hits the < 8 branch.
494+
key = mx.random.key(0)
495+
k1, k2 = mx.random.split(key)
496+
K = 512
497+
for mode, group_size, bits in [
498+
("mxfp4", 32, 4),
499+
("mxfp8", 32, 8),
500+
("nvfp4", 16, 4),
501+
]:
502+
for M in [1, 2]:
503+
for N in [1, 2, 3, 5, 7]:
504+
with self.subTest(M=M, K=K, N=N, mode=mode):
505+
x = mx.random.normal(shape=(M, K), key=k1) / K**0.5
506+
w = mx.random.normal(shape=(N, K), key=k2) / K**0.5
507+
w_q, scales = mx.quantize(
508+
w, group_size=group_size, bits=bits, mode=mode
509+
)
510+
w_hat = mx.dequantize(
511+
w_q, scales, group_size=group_size, bits=bits, mode=mode
512+
)
513+
y_q = mx.quantized_matmul(
514+
x,
515+
w_q,
516+
scales,
517+
transpose=True,
518+
group_size=group_size,
519+
bits=bits,
520+
mode=mode,
521+
)
522+
y_hat = x @ mx.swapaxes(w_hat, -1, -2)
523+
self.assertEqual(y_q.shape, y_hat.shape)
524+
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
525+
485526
def test_qvm(self):
486527
key = mx.random.key(0)
487528
k1, k2 = mx.random.split(key)

0 commit comments

Comments
 (0)