
Commit f3e49ff

Add W4A8 INT8 activation kernels for batched MoE prefill (#19226)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #19187 by @digantdesai
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/digantdesai/50/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/digantdesai/50/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/digantdesai/50/orig
@diff-train-skip-merge

Co-authored-by: Digant Desai <digantdesai@meta.com>
1 parent 798c121 commit f3e49ff
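W4A8 here means group-quantized INT4 weights paired with dynamically quantized INT8 activations. As a rough, hypothetical sketch of the activation side (not code from this commit; the function name and the per-token absmax scaling choice are assumptions), per-token dynamic INT8 quantization in a W4A8 pipeline typically looks like:

# Hypothetical sketch, not from this PR: per-token dynamic INT8 activation
# quantization as commonly used in W4A8 schemes (weights stay group-wise INT4).
import torch

def quantize_activations_int8(x: torch.Tensor):
    # x: [num_tokens, hidden] in bf16/fp16; one scale per token (row-wise absmax).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return x_int8, scale  # approximate dequant: x_int8.to(x.dtype) * scale

In the tests below, the bf16 activations are passed directly to fused_moe_batched_gemm_int8, so any such quantization would happen inside the kernel; the sketch above is only to clarify the naming.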

5 files changed

Lines changed: 618 additions & 5 deletions

File tree

.ci/scripts/export_model_artifact.sh

Lines changed: 2 additions & 1 deletion
@@ -418,7 +418,8 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
   TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \
     python -m executorch.examples.models.qwen3_5_moe.export \
       --prequantized "$LOCAL_MODEL_DIR" \
-      --output-dir "${OUTPUT_DIR}"
+      --output-dir "${OUTPUT_DIR}" \
+      --moe-activation-dtype int8
   echo "::endgroup::"
 
   test -f "${OUTPUT_DIR}/model.pte"

backends/cuda/tests/test_fused_moe.py

Lines changed: 152 additions & 0 deletions
@@ -31,6 +31,7 @@
 from executorch.backends.cuda.triton.kernels.fused_moe import (
     fused_moe as triton_fused_moe,
     fused_moe_batched as triton_fused_moe_batched,
+    fused_moe_batched_gemm_int8 as triton_fused_moe_batched_int8,
     moe_align_block_size,
 )
 from executorch.exir import (
@@ -212,6 +213,11 @@ def _run_cpp_runner(runner_path, pte_path, ptd_path, input_files, output_base):
 
 
 class TestFusedMoE(unittest.TestCase):
+    # TODO: migrate from manual max_abs/max_ref relative checks to
+    # torch.allclose(atol=, rtol=). Current tests use per-tensor-max relative
+    # error which is looser than per-element allclose — need to calibrate atol
+    # for INT4 quantization noise floor across random weight magnitudes.
+
     def setUp(self):
         if not torch.cuda.is_available():
             self.skipTest("CUDA is not available")
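As a side note on the TODO above (not part of the diff): the existing checks normalize the max absolute error by the reference tensor's max magnitude, while torch.allclose bounds every element by atol + rtol * |ref|. A small, self-contained illustration of the difference, using made-up tensors rather than test data:

import torch

ref = torch.randn(8, 16)
out = ref + 0.01 * torch.randn(8, 16)  # stand-in for a kernel output with small noise

# Per-tensor-max relative error: a single scalar, normalized by the largest
# reference magnitude, so small elements can be arbitrarily wrong in relative terms.
rel = ((out - ref).abs().max() / (ref.abs().max() + 1e-10)).item()

# torch.allclose: every element must satisfy |out - ref| <= atol + rtol * |ref|,
# the stricter per-element criterion the TODO wants to calibrate atol for.
ok = torch.allclose(out, ref, atol=5e-2, rtol=5e-2)
print(rel, ok)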
@@ -487,6 +493,152 @@ def test_e2e_cpp_runner(self):
         )
 
 
+class TestFusedMoEBatchedInt8(unittest.TestCase):
+    """Correctness tests for the INT8 dynamic-activation batched MoE kernel."""
+
+    INT8_TEST_CONFIGS = [
+        (42, 8, 64, 32, 4, 2, 32, "8tok_small"),
+        (7, 16, 64, 32, 8, 4, 32, "16tok_8exp_top4"),
+        (13, 32, 128, 64, 8, 2, 64, "32tok_gs64"),
+        (55, 64, 64, 32, 4, 2, 32, "64tok"),
+        (99, 128, 128, 64, 8, 2, 32, "128tok"),
+        (0, 256, 128, 64, 8, 2, 32, "256tok"),
+    ]
+
+    def test_int8_correctness(self):
+        """INT8 batched kernel matches reference across M values."""
+        for (
+            seed,
+            M,
+            hidden,
+            intermediate,
+            num_experts,
+            top_k,
+            gs,
+            desc,
+        ) in self.INT8_TEST_CONFIGS:
+            with self.subTest(desc=desc):
+                torch.manual_seed(seed)
+                x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
+                w1_weight = torch.randn(
+                    num_experts,
+                    2 * intermediate,
+                    hidden,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w2_weight = torch.randn(
+                    num_experts,
+                    hidden,
+                    intermediate,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs)
+                w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs)
+                w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda()
+
+                scores = torch.randn(M, num_experts, device="cuda")
+                topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
+                topk_weights = topk_weights.softmax(dim=-1).float()
+
+                out_int8 = triton_fused_moe_batched_int8(
+                    x,
+                    w1,
+                    w1s,
+                    w2,
+                    w2s,
+                    topk_weights,
+                    topk_ids,
+                    top_k,
+                    num_experts,
+                    gs,
+                )
+
+                w1_dq = _dequantize_int4(w1.cpu(), w1s.cpu(), gs).cuda()
+                w2_dq = _dequantize_int4(w2.cpu(), w2s.cpu(), gs).cuda()
+                ref = _reference_moe(x, w1_dq, w2_dq, topk_weights, topk_ids, top_k)
+
+                diff = (out_int8.float() - ref.float()).abs().max().item()
+                rel = diff / (ref.float().abs().max().item() + 1e-10)
+                self.assertLess(
+                    rel,
+                    0.10,
+                    f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})",
+                )
+
+    def test_int8_matches_bf16_batched(self):
+        """INT8 batched output is close to BF16 batched output."""
+        for (
+            seed,
+            M,
+            hidden,
+            intermediate,
+            num_experts,
+            top_k,
+            gs,
+            desc,
+        ) in self.INT8_TEST_CONFIGS:
+            with self.subTest(desc=desc):
+                torch.manual_seed(seed)
+                x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
+                w1_weight = torch.randn(
+                    num_experts,
+                    2 * intermediate,
+                    hidden,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w2_weight = torch.randn(
+                    num_experts,
+                    hidden,
+                    intermediate,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs)
+                w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs)
+                w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda()
+
+                scores = torch.randn(M, num_experts, device="cuda")
+                topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
+                topk_weights = topk_weights.softmax(dim=-1).float()
+
+                out_bf16 = triton_fused_moe_batched(
+                    x,
+                    w1,
+                    w1s,
+                    w2,
+                    w2s,
+                    topk_weights,
+                    topk_ids,
+                    top_k,
+                    num_experts,
+                    gs,
+                )
+
+                out_int8 = triton_fused_moe_batched_int8(
+                    x,
+                    w1,
+                    w1s,
+                    w2,
+                    w2s,
+                    topk_weights,
+                    topk_ids,
+                    top_k,
+                    num_experts,
+                    gs,
+                )
+
+                diff = (out_int8.float() - out_bf16.float()).abs().max().item()
+                rel = diff / (out_bf16.float().abs().max().item() + 1e-10)
+                self.assertLess(
+                    rel,
+                    0.15,
+                    f"{desc}: int8 vs bf16 relative diff {rel:.4f} (abs {diff:.6f})",
+                )
+
+
 class TestMoeAlignBlockSize(unittest.TestCase):
     def setUp(self):
         if not torch.cuda.is_available():
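The helpers _quantize_weights_int4, _dequantize_int4, and _reference_moe used above are defined earlier in this test file and are not shown in this diff. As a hedged sketch of what such a reference typically computes for a SwiGLU-style MoE layer (the SiLU activation and the [gate; up] packing of w1 are assumptions, and this is not the PR's implementation):

# Hypothetical reference MoE forward, mirroring the call sites above; assumes w1
# packs gate and up projections and that experts use SiLU gating.
import torch
import torch.nn.functional as F

def reference_moe(x, w1, w2, topk_weights, topk_ids, top_k):
    # x: [M, hidden]; w1: [E, 2*intermediate, hidden]; w2: [E, hidden, intermediate]
    out = torch.zeros_like(x, dtype=torch.float32)
    for expert in range(w1.shape[0]):
        # Tokens routed to this expert and the top-k slot they occupy.
        token_idx, slot = (topk_ids == expert).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx].float() @ w1[expert].float().t()   # [n, 2*intermediate]
        gate, up = h.chunk(2, dim=-1)
        h = F.silu(gate) * up                               # [n, intermediate]
        h = h @ w2[expert].float().t()                      # [n, hidden]
        # Accumulate each expert's contribution, weighted by its routing weight.
        out.index_add_(0, token_idx, h * topk_weights[token_idx, slot].unsqueeze(-1))
    return out.to(x.dtype)

Per the assertions above, the kernels under test are expected to match this kind of reference within max-relative tolerances of 0.10 and 0.15, which absorb both INT4 weight quantization and INT8 activation quantization error.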
