
Commit 594009d

Update on "Add W4A8 INT8 activation kernels for batched MoE prefill"
Adds INT8 tensor-core variants of the batched MoE GEMM kernels that dynamically quantize bf16 activations to INT8 per row, per tile, and dequantize the INT4 weights directly to INT8, skipping the bf16 conversion. The kernels use tl.dot(int8, int8) → int32 accumulation with a per-tile float32 rescale. This gives a 1.7× MoE speedup on A100 at M=1024 with 0.9998 cosine similarity vs the bf16 baseline.

Co-authored-by: Claude <noreply@anthropic.com>

[ghstack-poisoned]
1 parent a6aba5e commit 594009d
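
For context, the following is a minimal PyTorch-only sketch of the W4A8 arithmetic the commit message describes; it is not code from this commit and not the Triton kernel itself. It shows per-row dynamic INT8 quantization of bf16 activations, an integer GEMM against INT8 weight values, and a float32 rescale of the integer accumulator. All names are illustrative; the real kernel expands INT4 weight groups to INT8 and accumulates with tl.dot(int8, int8) → int32 inside the Triton tile loop.

import torch


def int8_gemm_sketch(x_bf16, w_int8, w_scale):
    # Per-row dynamic quantization: scale each activation row into [-127, 127].
    x_f32 = x_bf16.float()
    x_scale = x_f32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0  # (M, 1)
    x_int8 = torch.clamp(torch.round(x_f32 / x_scale), -127, 127).to(torch.int8)

    # Integer GEMM. The real kernel uses tl.dot(int8, int8) -> int32 on tensor
    # cores; int64 is used here only so this emulation runs on CPU.
    acc = x_int8.to(torch.int64) @ w_int8.to(torch.int64).t()  # (M, N)

    # Rescale the integer accumulator back to real values.
    return acc.float() * x_scale * w_scale.t()  # (M, N)


# Quick check against a float reference. w_int8/w_scale stand in for weights
# that the real kernel would obtain by expanding INT4 groups to INT8.
M, K, N = 64, 128, 96
x = torch.randn(M, K, dtype=torch.bfloat16)
w_int8 = torch.randint(-127, 128, (N, K), dtype=torch.int8)
w_scale = torch.rand(N, 1) * 0.05 + 0.01
ref = x.float() @ (w_int8.float() * w_scale).t()
out = int8_gemm_sketch(x, w_int8, w_scale)
print((out - ref).abs().max().item() / ref.abs().max().item())  # small relative error expected
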

2 files changed

Lines changed: 147 additions & 37 deletions


backends/cuda/benchmarks/benchmark_moe.py

Lines changed: 0 additions & 37 deletions
@@ -247,34 +247,6 @@ def _run_triton_batched(
         )
 
     BACKENDS["triton_batched"] = ("Triton batched", _run_triton_batched)
-
-    def _run_triton_batched_int8(
-        hidden_states,
-        w1,
-        w1_scale,
-        w2,
-        w2_scale,
-        topk_weights,
-        topk_ids,
-        top_k,
-        num_experts,
-        group_size,
-    ):
-        return fused_moe_batched(
-            hidden_states,
-            w1,
-            w1_scale,
-            w2,
-            w2_scale,
-            topk_weights,
-            topk_ids,
-            top_k=top_k,
-            num_experts=num_experts,
-            group_size=group_size,
-            activation_dtype="int8",
-        )
-
-    BACKENDS["triton_batched_int8"] = ("Triton bat-i8", _run_triton_batched_int8)
 except ImportError:
     pass

@@ -386,15 +358,6 @@ def run_benchmark(
                 f"Triton vs eager mismatch at M={M}: "
                 f"max abs error {err:.3e} >= 2.0e-1"
             )
-            if "triton_batched_int8" in BACKENDS:
-                _, _, run_int8 = BACKENDS["triton_batched_int8"]
-                int8_out = run_int8(**common_args)
-                int8_err = _max_abs_error(int8_out, ref_out)
-                assert int8_err < 5.0e-1, (
-                    f"Triton INT8 vs eager mismatch at M={M}: "
-                    f"max abs error {int8_err:.3e} >= 5.0e-1"
-                )
-                del int8_out
             del ref_out, tri_out
 
             # Benchmark

backends/cuda/tests/test_fused_moe.py

Lines changed: 147 additions & 0 deletions
@@ -31,6 +31,7 @@
 from executorch.backends.cuda.triton.kernels.fused_moe import (
     fused_moe as triton_fused_moe,
     fused_moe_batched as triton_fused_moe_batched,
+    fused_moe_batched_gemm_int8 as triton_fused_moe_batched_int8,
     moe_align_block_size,
 )
 from executorch.exir import (
@@ -487,6 +488,152 @@ def test_e2e_cpp_runner(self):
         )
 
 
+class TestFusedMoEBatchedInt8(unittest.TestCase):
+    """Correctness tests for the INT8 dynamic-activation batched MoE kernel."""
+
+    INT8_TEST_CONFIGS = [
+        (42, 8, 64, 32, 4, 2, 32, "8tok_small"),
+        (7, 16, 64, 32, 8, 4, 32, "16tok_8exp_top4"),
+        (13, 32, 128, 64, 8, 2, 64, "32tok_gs64"),
+        (55, 64, 64, 32, 4, 2, 32, "64tok"),
+        (99, 128, 128, 64, 8, 2, 32, "128tok"),
+        (0, 256, 128, 64, 8, 2, 32, "256tok"),
+    ]
+
+    def test_int8_correctness(self):
+        """INT8 batched kernel matches reference across M values."""
+        for (
+            seed,
+            M,
+            hidden,
+            intermediate,
+            num_experts,
+            top_k,
+            gs,
+            desc,
+        ) in self.INT8_TEST_CONFIGS:
+            with self.subTest(desc=desc):
+                torch.manual_seed(seed)
+                x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
+                w1_weight = torch.randn(
+                    num_experts,
+                    2 * intermediate,
+                    hidden,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w2_weight = torch.randn(
+                    num_experts,
+                    hidden,
+                    intermediate,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs)
+                w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs)
+                w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda()
+
+                scores = torch.randn(M, num_experts, device="cuda")
+                topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
+                topk_weights = topk_weights.softmax(dim=-1).float()
+
+                out_int8 = triton_fused_moe_batched_int8(
+                    x,
+                    w1,
+                    w1s,
+                    w2,
+                    w2s,
+                    topk_weights,
+                    topk_ids,
+                    top_k,
+                    num_experts,
+                    gs,
+                )
+
+                w1_dq = _dequantize_int4(w1.cpu(), w1s.cpu(), gs).cuda()
+                w2_dq = _dequantize_int4(w2.cpu(), w2s.cpu(), gs).cuda()
+                ref = _reference_moe(x, w1_dq, w2_dq, topk_weights, topk_ids, top_k)
+
+                diff = (out_int8.float() - ref.float()).abs().max().item()
+                rel = diff / (ref.float().abs().max().item() + 1e-10)
+                self.assertLess(
+                    rel,
+                    0.10,
+                    f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})",
+                )
+
+    def test_int8_matches_bf16_batched(self):
+        """INT8 batched output is close to BF16 batched output."""
+        for (
+            seed,
+            M,
+            hidden,
+            intermediate,
+            num_experts,
+            top_k,
+            gs,
+            desc,
+        ) in self.INT8_TEST_CONFIGS:
+            with self.subTest(desc=desc):
+                torch.manual_seed(seed)
+                x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
+                w1_weight = torch.randn(
+                    num_experts,
+                    2 * intermediate,
+                    hidden,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w2_weight = torch.randn(
+                    num_experts,
+                    hidden,
+                    intermediate,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs)
+                w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs)
+                w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda()
+
+                scores = torch.randn(M, num_experts, device="cuda")
+                topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
+                topk_weights = topk_weights.softmax(dim=-1).float()
+
+                out_bf16 = triton_fused_moe_batched(
+                    x,
+                    w1,
+                    w1s,
+                    w2,
+                    w2s,
+                    topk_weights,
+                    topk_ids,
+                    top_k,
+                    num_experts,
+                    gs,
+                )
+
+                out_int8 = triton_fused_moe_batched_int8(
+                    x,
+                    w1,
+                    w1s,
+                    w2,
+                    w2s,
+                    topk_weights,
+                    topk_ids,
+                    top_k,
+                    num_experts,
+                    gs,
+                )
+
+                diff = (out_int8.float() - out_bf16.float()).abs().max().item()
+                rel = diff / (out_bf16.float().abs().max().item() + 1e-10)
+                self.assertLess(
+                    rel,
+                    0.15,
+                    f"{desc}: int8 vs bf16 relative diff {rel:.4f} (abs {diff:.6f})",
+                )
+
+
 class TestMoeAlignBlockSize(unittest.TestCase):
     def setUp(self):
         if not torch.cuda.is_available():
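
For orientation: the tests above rely on helpers (_quantize_weights_int4, _dequantize_int4, _reference_moe) defined elsewhere in test_fused_moe.py and not shown in this diff. Below is a minimal sketch of what a dense reference MoE with these tensor shapes typically looks like, assuming a fused gate/up w1 of shape [E, 2*I, H], w2 of shape [E, H, I], and SiLU gating; it is illustrative only, and the repo's actual _reference_moe may differ.

import torch
import torch.nn.functional as F


def reference_moe_sketch(x, w1, w2, topk_weights, topk_ids, top_k):
    """Dense per-token MoE: out[t] = sum_k weight[t, k] * expert_{id[t, k]}(x[t])."""
    M, H = x.shape
    out = torch.zeros(M, H, dtype=torch.float32, device=x.device)
    x_f = x.float()
    for t in range(M):
        for k in range(top_k):
            e = int(topk_ids[t, k])
            gate_up = x_f[t] @ w1[e].float().t()   # (2*I,) fused gate/up projection
            gate, up = gate_up.chunk(2)
            h = F.silu(gate) * up                  # (I,) gated activation
            out[t] += topk_weights[t, k] * (h @ w2[e].float().t())  # (H,) down projection
    return out.to(x.dtype)


# Tiny CPU smoke test with shapes matching the conventions used in the tests above.
M, H, I, E, top_k = 4, 8, 16, 4, 2
x = torch.randn(M, H, dtype=torch.bfloat16)
w1 = torch.randn(E, 2 * I, H, dtype=torch.bfloat16)
w2 = torch.randn(E, H, I, dtype=torch.bfloat16)
scores = torch.randn(M, E)
topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
topk_weights = topk_weights.softmax(dim=-1).float()
print(reference_moe_sketch(x, w1, w2, topk_weights, topk_ids, top_k).shape)  # torch.Size([4, 8])
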
