Skip to content

Commit 554d4b5

Browse files
committed
Add W4A8 INT8 activation kernels for batched MoE prefill
INT8 tensor core variants of the batched MoE GEMM kernels that dynamically quantize bf16 activations to INT8 per-row per-tile and dequantize INT4 weights directly to INT8 (skipping bf16 conversion). Uses tl.dot(int8, int8) → int32 accumulation with per-tile float32 rescale. 1.7× MoE speedup on A100 at M=1024 with 0.9998 cosine similarity vs bf16 baseline. Co-authored-by: Claude <noreply@anthropic.com> ghstack-source-id: 96200a1 Pull Request resolved: #19187
1 parent cb4e5ae commit 554d4b5

4 files changed

Lines changed: 616 additions & 4 deletions

File tree

backends/cuda/tests/test_fused_moe.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from executorch.backends.cuda.triton.kernels.fused_moe import (
3232
fused_moe as triton_fused_moe,
3333
fused_moe_batched as triton_fused_moe_batched,
34+
fused_moe_batched_gemm_int8 as triton_fused_moe_batched_int8,
3435
moe_align_block_size,
3536
)
3637
from executorch.exir import (
@@ -212,6 +213,11 @@ def _run_cpp_runner(runner_path, pte_path, ptd_path, input_files, output_base):
212213

213214

214215
class TestFusedMoE(unittest.TestCase):
216+
# TODO: migrate from manual max_abs/max_ref relative checks to
217+
# torch.allclose(atol=, rtol=). Current tests use per-tensor-max relative
218+
# error which is looser than per-element allclose — need to calibrate atol
219+
# for INT4 quantization noise floor across random weight magnitudes.
220+
215221
def setUp(self):
216222
if not torch.cuda.is_available():
217223
self.skipTest("CUDA is not available")
@@ -487,6 +493,152 @@ def test_e2e_cpp_runner(self):
487493
)
488494

489495

496+
class TestFusedMoEBatchedInt8(unittest.TestCase):
    """Correctness tests for the INT8 dynamic-activation batched MoE kernel."""

    # Each tuple: (seed, M, hidden, intermediate, num_experts, top_k,
    # group_size, description). M spans small odd sizes up to 256 tokens so
    # both partial and full tiles of the batched kernel are exercised.
    INT8_TEST_CONFIGS = [
        (42, 8, 64, 32, 4, 2, 32, "8tok_small"),
        (7, 16, 64, 32, 8, 4, 32, "16tok_8exp_top4"),
        (13, 32, 128, 64, 8, 2, 64, "32tok_gs64"),
        (55, 64, 64, 32, 4, 2, 32, "64tok"),
        (99, 128, 128, 64, 8, 2, 32, "128tok"),
        (0, 256, 128, 64, 8, 2, 32, "256tok"),
    ]

    def setUp(self):
        # Consistent with the sibling test classes in this file: the Triton
        # kernels need a CUDA device, so skip (rather than fail) on CPU-only
        # machines.
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

    def _make_inputs(
        self, seed, M, hidden, intermediate, num_experts, top_k, gs
    ):
        """Build one test case: activations, INT4 weights, and routing.

        Returns (x, w1, w1s, w2, w2s, topk_weights, topk_ids), all on CUDA.
        Weights are quantized on CPU via the module-level helper and then
        moved to the device, matching how the other MoE tests prepare data.
        """
        torch.manual_seed(seed)
        x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
        # w1 is the fused gate/up projection, hence the 2*intermediate rows.
        w1_weight = torch.randn(
            num_experts,
            2 * intermediate,
            hidden,
            dtype=torch.bfloat16,
            device="cuda",
        )
        w2_weight = torch.randn(
            num_experts,
            hidden,
            intermediate,
            dtype=torch.bfloat16,
            device="cuda",
        )
        w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs)
        w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs)
        w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda()

        # Random router scores -> top-k expert ids with softmax-normalized
        # per-token weights, as the kernels expect.
        scores = torch.randn(M, num_experts, device="cuda")
        topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
        topk_weights = topk_weights.softmax(dim=-1).float()
        return x, w1, w1s, w2, w2s, topk_weights, topk_ids

    def test_int8_correctness(self):
        """INT8 batched kernel matches reference across M values."""
        for (
            seed,
            M,
            hidden,
            intermediate,
            num_experts,
            top_k,
            gs,
            desc,
        ) in self.INT8_TEST_CONFIGS:
            with self.subTest(desc=desc):
                x, w1, w1s, w2, w2s, topk_weights, topk_ids = (
                    self._make_inputs(
                        seed, M, hidden, intermediate, num_experts, top_k, gs
                    )
                )

                out_int8 = triton_fused_moe_batched_int8(
                    x,
                    w1,
                    w1s,
                    w2,
                    w2s,
                    topk_weights,
                    topk_ids,
                    top_k,
                    num_experts,
                    gs,
                )

                # Reference path: dequantize the INT4 weights back to bf16
                # and run the plain PyTorch MoE implementation.
                w1_dq = _dequantize_int4(w1.cpu(), w1s.cpu(), gs).cuda()
                w2_dq = _dequantize_int4(w2.cpu(), w2s.cpu(), gs).cuda()
                ref = _reference_moe(x, w1_dq, w2_dq, topk_weights, topk_ids, top_k)

                # Per-tensor-max relative error; tolerance covers INT8
                # activation quantization noise on top of INT4 weights.
                diff = (out_int8.float() - ref.float()).abs().max().item()
                rel = diff / (ref.float().abs().max().item() + 1e-10)
                self.assertLess(
                    rel,
                    0.10,
                    f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})",
                )

    def test_int8_matches_bf16_batched(self):
        """INT8 batched output is close to BF16 batched output."""
        for (
            seed,
            M,
            hidden,
            intermediate,
            num_experts,
            top_k,
            gs,
            desc,
        ) in self.INT8_TEST_CONFIGS:
            with self.subTest(desc=desc):
                x, w1, w1s, w2, w2s, topk_weights, topk_ids = (
                    self._make_inputs(
                        seed, M, hidden, intermediate, num_experts, top_k, gs
                    )
                )

                out_bf16 = triton_fused_moe_batched(
                    x,
                    w1,
                    w1s,
                    w2,
                    w2s,
                    topk_weights,
                    topk_ids,
                    top_k,
                    num_experts,
                    gs,
                )

                out_int8 = triton_fused_moe_batched_int8(
                    x,
                    w1,
                    w1s,
                    w2,
                    w2s,
                    topk_weights,
                    topk_ids,
                    top_k,
                    num_experts,
                    gs,
                )

                # Both kernels see identical INT4 weights, so the only
                # difference is INT8 activation quantization -> slightly
                # looser tolerance than the reference comparison above.
                diff = (out_int8.float() - out_bf16.float()).abs().max().item()
                rel = diff / (out_bf16.float().abs().max().item() + 1e-10)
                self.assertLess(
                    rel,
                    0.15,
                    f"{desc}: int8 vs bf16 relative diff {rel:.4f} (abs {diff:.6f})",
                )
490642
class TestMoeAlignBlockSize(unittest.TestCase):
491643
def setUp(self):
492644
if not torch.cuda.is_available():

0 commit comments

Comments
 (0)