|
31 | 31 | from executorch.backends.cuda.triton.kernels.fused_moe import ( |
32 | 32 | fused_moe as triton_fused_moe, |
33 | 33 | fused_moe_batched as triton_fused_moe_batched, |
| 34 | + fused_moe_batched_gemm_int8 as triton_fused_moe_batched_int8, |
34 | 35 | moe_align_block_size, |
35 | 36 | ) |
36 | 37 | from executorch.exir import ( |
@@ -487,6 +488,152 @@ def test_e2e_cpp_runner(self): |
487 | 488 | ) |
488 | 489 |
|
489 | 490 |
|
class TestFusedMoEBatchedInt8(unittest.TestCase):
    """Correctness tests for the INT8 dynamic-activation batched MoE kernel.

    Each config tuple is (seed, M, hidden, intermediate, num_experts, top_k,
    group_size, description); M is the token count fed to the kernel.
    """

    INT8_TEST_CONFIGS = [
        (42, 8, 64, 32, 4, 2, 32, "8tok_small"),
        (7, 16, 64, 32, 8, 4, 32, "16tok_8exp_top4"),
        (13, 32, 128, 64, 8, 2, 64, "32tok_gs64"),
        (55, 64, 64, 32, 4, 2, 32, "64tok"),
        (99, 128, 128, 64, 8, 2, 32, "128tok"),
        (0, 256, 128, 64, 8, 2, 32, "256tok"),
    ]

    def setUp(self):
        # Consistent with TestMoeAlignBlockSize in this file: skip (rather
        # than error) when no CUDA device is present, since every fixture
        # below allocates tensors with device="cuda".
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

    def _build_case(self, seed, M, hidden, intermediate, num_experts, top_k, gs):
        """Build one test fixture: activations, int4-quantized weights, routing.

        Returns (x, w1, w1s, w2, w2s, topk_weights, topk_ids). RNG calls are
        made in the same order for every caller so results are reproducible
        from the seed alone.
        """
        torch.manual_seed(seed)
        x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
        # w1 fuses gate and up projections, hence the 2 * intermediate rows.
        w1_weight = torch.randn(
            num_experts,
            2 * intermediate,
            hidden,
            dtype=torch.bfloat16,
            device="cuda",
        )
        w2_weight = torch.randn(
            num_experts,
            hidden,
            intermediate,
            dtype=torch.bfloat16,
            device="cuda",
        )
        # Quantization helper runs on CPU; move the packed weights and
        # per-group scales back to the GPU afterwards.
        w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs)
        w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs)
        w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda()

        # Random router: top-k expert selection with softmax-normalized gates.
        scores = torch.randn(M, num_experts, device="cuda")
        topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
        topk_weights = topk_weights.softmax(dim=-1).float()
        return x, w1, w1s, w2, w2s, topk_weights, topk_ids

    def test_int8_correctness(self):
        """INT8 batched kernel matches a dequantized reference across M values."""
        for (
            seed,
            M,
            hidden,
            intermediate,
            num_experts,
            top_k,
            gs,
            desc,
        ) in self.INT8_TEST_CONFIGS:
            with self.subTest(desc=desc):
                x, w1, w1s, w2, w2s, topk_weights, topk_ids = self._build_case(
                    seed, M, hidden, intermediate, num_experts, top_k, gs
                )

                out_int8 = triton_fused_moe_batched_int8(
                    x,
                    w1,
                    w1s,
                    w2,
                    w2s,
                    topk_weights,
                    topk_ids,
                    top_k,
                    num_experts,
                    gs,
                )

                # Reference path: dequantize the int4 weights and run the
                # straightforward (unfused) MoE computation.
                w1_dq = _dequantize_int4(w1.cpu(), w1s.cpu(), gs).cuda()
                w2_dq = _dequantize_int4(w2.cpu(), w2s.cpu(), gs).cuda()
                ref = _reference_moe(x, w1_dq, w2_dq, topk_weights, topk_ids, top_k)

                # Compare via max-abs error relative to the reference's peak
                # magnitude; the epsilon guards against an all-zero reference.
                diff = (out_int8.float() - ref.float()).abs().max().item()
                rel = diff / (ref.float().abs().max().item() + 1e-10)
                self.assertLess(
                    rel,
                    0.10,
                    f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})",
                )

    def test_int8_matches_bf16_batched(self):
        """INT8 batched output is close to the BF16 batched kernel's output."""
        for (
            seed,
            M,
            hidden,
            intermediate,
            num_experts,
            top_k,
            gs,
            desc,
        ) in self.INT8_TEST_CONFIGS:
            with self.subTest(desc=desc):
                x, w1, w1s, w2, w2s, topk_weights, topk_ids = self._build_case(
                    seed, M, hidden, intermediate, num_experts, top_k, gs
                )

                out_bf16 = triton_fused_moe_batched(
                    x,
                    w1,
                    w1s,
                    w2,
                    w2s,
                    topk_weights,
                    topk_ids,
                    top_k,
                    num_experts,
                    gs,
                )

                out_int8 = triton_fused_moe_batched_int8(
                    x,
                    w1,
                    w1s,
                    w2,
                    w2s,
                    topk_weights,
                    topk_ids,
                    top_k,
                    num_experts,
                    gs,
                )

                # Kernel-vs-kernel comparison gets a looser bound (0.15) than
                # the reference test since both sides carry quantization noise.
                diff = (out_int8.float() - out_bf16.float()).abs().max().item()
                rel = diff / (out_bf16.float().abs().max().item() + 1e-10)
                self.assertLess(
                    rel,
                    0.15,
                    f"{desc}: int8 vs bf16 relative diff {rel:.4f} (abs {diff:.6f})",
                )
490 | 637 | class TestMoeAlignBlockSize(unittest.TestCase): |
491 | 638 | def setUp(self): |
492 | 639 | if not torch.cuda.is_available(): |
|
0 commit comments