|
30 | 30 | from executorch.backends.cuda.cuda_partitioner import CudaPartitioner |
31 | 31 | from executorch.backends.cuda.triton.kernels.fused_moe import ( |
32 | 32 | fused_moe as triton_fused_moe, |
| 33 | + fused_moe_batched as triton_fused_moe_batched, |
| 34 | + moe_align_block_size, |
33 | 35 | ) |
34 | 36 | from executorch.exir import ( |
35 | 37 | EdgeCompileConfig, |
@@ -332,6 +334,96 @@ def test_single_expert(self): |
332 | 334 | rel = diff / (ref.float().abs().max().item() + 1e-10) |
333 | 335 | self.assertLess(rel, 0.05, f"token {t}: relative diff {rel:.4f}") |
334 | 336 |
|
| 337 | + def test_batched_correctness(self): |
| 338 | + """Batched kernel matches reference across M values.""" |
| 339 | + test_cases = [ |
| 340 | + (42, 8, 64, 32, 4, 2, 32, "8tok_small"), |
| 341 | + (7, 16, 64, 32, 8, 4, 32, "16tok_8exp_top4"), |
| 342 | + (13, 32, 128, 64, 8, 2, 64, "32tok_gs64"), |
| 343 | + (55, 64, 64, 32, 4, 2, 32, "64tok"), |
| 344 | + (99, 128, 128, 64, 8, 2, 32, "128tok"), |
| 345 | + ] |
| 346 | + for seed, M, hidden, intermediate, num_experts, top_k, gs, desc in test_cases: |
| 347 | + with self.subTest(desc=desc): |
| 348 | + torch.manual_seed(seed) |
| 349 | + x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") |
| 350 | + w1_weight = torch.randn( |
| 351 | + num_experts, |
| 352 | + 2 * intermediate, |
| 353 | + hidden, |
| 354 | + dtype=torch.bfloat16, |
| 355 | + device="cuda", |
| 356 | + ) |
| 357 | + w2_weight = torch.randn( |
| 358 | + num_experts, |
| 359 | + hidden, |
| 360 | + intermediate, |
| 361 | + dtype=torch.bfloat16, |
| 362 | + device="cuda", |
| 363 | + ) |
| 364 | + w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs) |
| 365 | + w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs) |
| 366 | + w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda() |
| 367 | + |
| 368 | + scores = torch.randn(M, num_experts, device="cuda") |
| 369 | + topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1) |
| 370 | + topk_weights = topk_weights.softmax(dim=-1).float() |
| 371 | + |
| 372 | + out = triton_fused_moe_batched( |
| 373 | + x, |
| 374 | + w1, |
| 375 | + w1s, |
| 376 | + w2, |
| 377 | + w2s, |
| 378 | + topk_weights, |
| 379 | + topk_ids, |
| 380 | + top_k, |
| 381 | + num_experts, |
| 382 | + gs, |
| 383 | + ) |
| 384 | + |
| 385 | + w1_dq = _dequantize_int4(w1.cpu(), w1s.cpu(), gs).cuda() |
| 386 | + w2_dq = _dequantize_int4(w2.cpu(), w2s.cpu(), gs).cuda() |
| 387 | + ref = _reference_moe(x, w1_dq, w2_dq, topk_weights, topk_ids, top_k) |
| 388 | + |
| 389 | + diff = (out.float() - ref.float()).abs().max().item() |
| 390 | + rel = diff / (ref.float().abs().max().item() + 1e-10) |
| 391 | + self.assertLess( |
| 392 | + rel, |
| 393 | + 0.05, |
| 394 | + f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})", |
| 395 | + ) |
| 396 | + |
| 397 | + def test_batched_matches_fused(self): |
| 398 | + """Batched kernel matches the existing fused_moe kernel at Qwen-scale dims.""" |
| 399 | + E, top_k, K, inter, gs = 256, 8, 2048, 512, 128 |
| 400 | + torch.manual_seed(42) |
| 401 | + vals = torch.randint(0, 16, (E, 2 * inter, K), dtype=torch.uint8, device="cuda") |
| 402 | + w1 = ((vals[:, :, 1::2] << 4) | vals[:, :, 0::2]).to(torch.int8) |
| 403 | + w1s = ( |
| 404 | + torch.randn(E, 2 * inter, K // gs, device="cuda", dtype=torch.bfloat16) |
| 405 | + * 0.01 |
| 406 | + ) |
| 407 | + vals = torch.randint(0, 16, (E, K, inter), dtype=torch.uint8, device="cuda") |
| 408 | + w2 = ((vals[:, :, 1::2] << 4) | vals[:, :, 0::2]).to(torch.int8) |
| 409 | + w2s = torch.randn(E, K, inter // gs, device="cuda", dtype=torch.bfloat16) * 0.01 |
| 410 | + |
| 411 | + for M in [16, 64, 256]: |
| 412 | + with self.subTest(M=M): |
| 413 | + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) |
| 414 | + logits = torch.randn(M, E, device="cuda", dtype=torch.float32) |
| 415 | + tw, ti = torch.topk(logits, top_k, dim=-1) |
| 416 | + tw = tw.softmax(dim=-1) |
| 417 | + ti = ti.to(torch.int64) |
| 418 | + |
| 419 | + out_fused = triton_fused_moe(x, w1, w1s, w2, w2s, tw, ti, top_k, E, gs) |
| 420 | + out_batched = triton_fused_moe_batched( |
| 421 | + x, w1, w1s, w2, w2s, tw, ti, top_k, E, gs |
| 422 | + ) |
| 423 | + |
| 424 | + err = (out_fused.float() - out_batched.float()).abs().max().item() |
| 425 | + self.assertLess(err, 0.5, f"M={M}: max abs error {err:.4e}") |
| 426 | + |
335 | 427 | def test_export_cuda(self): |
336 | 428 | """Export succeeds and produces non-empty .pte.""" |
337 | 429 | with tempfile.TemporaryDirectory() as tmpdir: |
@@ -395,6 +487,144 @@ def test_e2e_cpp_runner(self): |
395 | 487 | ) |
396 | 488 |
|
397 | 489 |
|
class TestMoeAlignBlockSize(unittest.TestCase):
    """Unit tests for moe_align_block_size token-pair bucketing and padding."""

    def setUp(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

    def test_basic_correctness(self):
        M, top_k, num_experts, block_size = 4, 2, 4, 4
        # Routing table:
        #   token 0 -> experts 0, 1
        #   token 1 -> experts 2, 3
        #   token 2 -> experts 0, 2
        #   token 3 -> experts 1, 3
        topk_ids = torch.tensor(
            [[0, 1], [2, 3], [0, 2], [1, 3]], dtype=torch.int64, device="cuda"
        )
        num_pairs = M * top_k
        sentinel = num_pairs

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, block_size, num_experts
        )

        # Output buffers are sized for the worst case: every expert padded out.
        max_padded = num_pairs + num_experts * block_size
        self.assertEqual(sorted_token_ids.shape[0], max_padded)
        self.assertEqual(expert_ids.shape[0], max_padded // block_size)

        # Each expert receives exactly 2 pairs, padded to block_size=4 -> 4*4=16.
        self.assertEqual(num_tokens_post_padded.item(), 16)

        # Inside the active region, every non-sentinel pair index must map to
        # the expert its block was assigned to.
        flat_ids = topk_ids.reshape(-1)
        n_active = num_tokens_post_padded.item()
        active = sorted_token_ids[:n_active]
        for block_idx in range(n_active // block_size):
            expert = expert_ids[block_idx].item()
            start = block_idx * block_size
            for pair_id in active[start : start + block_size].tolist():
                if pair_id == sentinel:
                    continue
                self.assertEqual(
                    flat_ids[pair_id].item(),
                    expert,
                    f"pair {pair_id} expected expert {expert}, got {flat_ids[pair_id].item()}",
                )

    def test_all_tokens_same_expert(self):
        M, top_k, num_experts, block_size = 4, 2, 4, 4
        topk_ids = torch.full((M, top_k), 2, dtype=torch.int64, device="cuda")
        num_pairs = M * top_k
        sentinel = num_pairs

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, block_size, num_experts
        )

        # All 8 pairs land on expert 2; 8 is already a multiple of block_size=4.
        self.assertEqual(num_tokens_post_padded.item(), 8)

        active = sorted_token_ids[: num_tokens_post_padded.item()]
        real_ids = active[active != sentinel]
        self.assertEqual(real_ids.shape[0], num_pairs)
        self.assertTrue(
            sorted(real_ids.tolist()) == list(range(num_pairs)),
            "All pair indices should appear exactly once",
        )

    def test_single_token(self):
        num_experts, block_size = 4, 4
        topk_ids = torch.tensor([[2]], dtype=torch.int64, device="cuda")
        num_pairs = 1
        sentinel = num_pairs

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, block_size, num_experts
        )

        # A lone pair still occupies one full block of block_size slots.
        self.assertEqual(num_tokens_post_padded.item(), block_size)

        active = sorted_token_ids[: num_tokens_post_padded.item()]
        self.assertEqual(active[active != sentinel].tolist(), [0])
        self.assertEqual((active == sentinel).sum().item(), block_size - 1)

    def test_num_pairs_less_than_block_size(self):
        M, top_k, num_experts, block_size = 1, 2, 4, 16
        topk_ids = torch.tensor([[0, 3]], dtype=torch.int64, device="cuda")
        num_pairs = M * top_k
        sentinel = num_pairs

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, block_size, num_experts
        )

        # One pair per expert -> two mostly-padding blocks of block_size=16.
        self.assertEqual(num_tokens_post_padded.item(), 2 * block_size)

        active = sorted_token_ids[: num_tokens_post_padded.item()]
        self.assertEqual(sorted(active[active != sentinel].tolist()), [0, 1])

    def test_sentinel_value(self):
        M, top_k, num_experts, block_size = 2, 2, 4, 4
        topk_ids = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="cuda")
        num_pairs = M * top_k
        sentinel = num_pairs

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, block_size, num_experts
        )

        # Every active slot is either a real pair index or the sentinel
        # (== num_pairs) marking padding.
        active = sorted_token_ids[: num_tokens_post_padded.item()]
        for val in active.tolist():
            self.assertTrue(
                0 <= val <= sentinel,
                f"Value {val} outside valid range [0, {sentinel}]",
            )

        # Everything beyond the active region must be pure sentinel padding.
        tail = sorted_token_ids[num_tokens_post_padded.item() :]
        self.assertTrue((tail == sentinel).all())

    def test_determinism(self):
        M, top_k, num_experts, block_size = 8, 4, 8, 4
        torch.manual_seed(42)
        topk_ids = torch.randint(0, num_experts, (M, top_k), device="cuda")

        # Five runs on identical input must produce identical outputs.
        baseline = moe_align_block_size(topk_ids, block_size, num_experts)
        for _ in range(4):
            repeat = moe_align_block_size(topk_ids, block_size, num_experts)
            self.assertTrue(torch.equal(baseline[0], repeat[0]))
            self.assertTrue(torch.equal(baseline[1], repeat[1]))
            self.assertEqual(baseline[2].item(), repeat[2].item())
| 626 | + |
| 627 | + |
398 | 628 | def _dequantize_int4(packed, scale, group_size): |
399 | 629 | """Dequantize packed INT4 [E, N, K//2] back to [E, N, K] float.""" |
400 | 630 | E, N, K_half = packed.shape |
|
0 commit comments