Skip to content

Commit 6204cf4

Browse files
committed
Add W4A8 INT8 activation kernels for batched MoE prefill
INT8 tensor core variants of the batched MoE GEMM kernels that dynamically quantize bf16 activations to INT8 per-row per-tile and dequantize INT4 weights directly to INT8 (skipping bf16 conversion). Uses tl.dot(int8, int8) → int32 accumulation with per-tile float32 rescale. 1.7× MoE speedup on A100 at M=1024 with 0.9998 cosine similarity vs bf16 baseline.

Co-authored-by: Claude <noreply@anthropic.com>
ghstack-source-id: b89fa45
Pull Request resolved: #19187
1 parent cb4e5ae commit 6204cf4

4 files changed

Lines changed: 611 additions & 4 deletions

File tree

backends/cuda/tests/test_fused_moe.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from executorch.backends.cuda.triton.kernels.fused_moe import (
3232
fused_moe as triton_fused_moe,
3333
fused_moe_batched as triton_fused_moe_batched,
34+
fused_moe_batched_gemm_int8 as triton_fused_moe_batched_int8,
3435
moe_align_block_size,
3536
)
3637
from executorch.exir import (
@@ -487,6 +488,152 @@ def test_e2e_cpp_runner(self):
487488
)
488489

489490

491+
class TestFusedMoEBatchedInt8(unittest.TestCase):
    """Correctness tests for the INT8 dynamic-activation batched MoE kernel.

    Each config tuple is (seed, M, hidden, intermediate, num_experts, top_k,
    group_size, description). Both tests share the same fixture construction,
    factored into ``_make_case`` so the two tests cannot drift apart.
    """

    INT8_TEST_CONFIGS = [
        (42, 8, 64, 32, 4, 2, 32, "8tok_small"),
        (7, 16, 64, 32, 8, 4, 32, "16tok_8exp_top4"),
        (13, 32, 128, 64, 8, 2, 64, "32tok_gs64"),
        (55, 64, 64, 32, 4, 2, 32, "64tok"),
        (99, 128, 128, 64, 8, 2, 32, "128tok"),
        (0, 256, 128, 64, 8, 2, 32, "256tok"),
    ]

    def setUp(self):
        # These kernels allocate directly on "cuda"; skip (rather than error)
        # on CPU-only hosts, consistent with TestMoeAlignBlockSize.setUp.
        if not torch.cuda.is_available():
            self.skipTest("CUDA is required for the batched MoE kernels")

    def _make_case(self, seed, M, hidden, intermediate, num_experts, top_k, gs):
        """Build one test fixture: activations, INT4 weights/scales, routing.

        Ops are performed in the same order as the original inline setup so
        the seeded RNG stream — and therefore every generated tensor — is
        byte-identical to what the tests produced before this refactor.

        Returns:
            (x, w1, w1s, w2, w2s, topk_weights, topk_ids), all on CUDA.
        """
        torch.manual_seed(seed)
        x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
        w1_weight = torch.randn(
            num_experts,
            2 * intermediate,  # gate and up projections are fused in w1
            hidden,
            dtype=torch.bfloat16,
            device="cuda",
        )
        w2_weight = torch.randn(
            num_experts,
            hidden,
            intermediate,
            dtype=torch.bfloat16,
            device="cuda",
        )
        # Quantization helper runs on CPU; move packed weights/scales back.
        w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs)
        w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs)
        w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda()

        scores = torch.randn(M, num_experts, device="cuda")
        topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
        topk_weights = topk_weights.softmax(dim=-1).float()
        return x, w1, w1s, w2, w2s, topk_weights, topk_ids

    def test_int8_correctness(self):
        """INT8 batched kernel matches the dequantized reference across M values."""
        for (
            seed,
            M,
            hidden,
            intermediate,
            num_experts,
            top_k,
            gs,
            desc,
        ) in self.INT8_TEST_CONFIGS:
            with self.subTest(desc=desc):
                x, w1, w1s, w2, w2s, topk_weights, topk_ids = self._make_case(
                    seed, M, hidden, intermediate, num_experts, top_k, gs
                )

                out_int8 = triton_fused_moe_batched_int8(
                    x,
                    w1,
                    w1s,
                    w2,
                    w2s,
                    topk_weights,
                    topk_ids,
                    top_k,
                    num_experts,
                    gs,
                )

                # Reference path: dequantize INT4 weights to bf16 and run the
                # eager MoE, so the comparison isolates INT8 activation error.
                w1_dq = _dequantize_int4(w1.cpu(), w1s.cpu(), gs).cuda()
                w2_dq = _dequantize_int4(w2.cpu(), w2s.cpu(), gs).cuda()
                ref = _reference_moe(x, w1_dq, w2_dq, topk_weights, topk_ids, top_k)

                diff = (out_int8.float() - ref.float()).abs().max().item()
                rel = diff / (ref.float().abs().max().item() + 1e-10)
                self.assertLess(
                    rel,
                    0.10,
                    f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})",
                )

    def test_int8_matches_bf16_batched(self):
        """INT8 batched output is close to BF16 batched output."""
        for (
            seed,
            M,
            hidden,
            intermediate,
            num_experts,
            top_k,
            gs,
            desc,
        ) in self.INT8_TEST_CONFIGS:
            with self.subTest(desc=desc):
                x, w1, w1s, w2, w2s, topk_weights, topk_ids = self._make_case(
                    seed, M, hidden, intermediate, num_experts, top_k, gs
                )

                out_bf16 = triton_fused_moe_batched(
                    x,
                    w1,
                    w1s,
                    w2,
                    w2s,
                    topk_weights,
                    topk_ids,
                    top_k,
                    num_experts,
                    gs,
                )

                out_int8 = triton_fused_moe_batched_int8(
                    x,
                    w1,
                    w1s,
                    w2,
                    w2s,
                    topk_weights,
                    topk_ids,
                    top_k,
                    num_experts,
                    gs,
                )

                # Kernel-vs-kernel check: looser 15% bound since both sides
                # carry quantization error relative to true bf16 math.
                diff = (out_int8.float() - out_bf16.float()).abs().max().item()
                rel = diff / (out_bf16.float().abs().max().item() + 1e-10)
                self.assertLess(
                    rel,
                    0.15,
                    f"{desc}: int8 vs bf16 relative diff {rel:.4f} (abs {diff:.6f})",
                )
490637
class TestMoeAlignBlockSize(unittest.TestCase):
491638
def setUp(self):
492639
if not torch.cuda.is_available():

0 commit comments

Comments
 (0)