Skip to content

Commit e285a2e

Browse files
committed
Add batched tensor-core MoE kernel for prefill
Adds a batched (M>1) Triton fused MoE kernel using tensor-core mma instructions for prefill workloads. Includes moe_align_block_size for token-expert sorting and scale broadcast optimization in the batched GEMM inner loops. Weight layout: [E, N, K//2] (packed INT4). This PR was authored with the assistance of Claude.
1 parent 61361ba commit e285a2e

3 files changed

Lines changed: 805 additions & 10 deletions

File tree

backends/cuda/tests/test_fused_moe.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
3131
from executorch.backends.cuda.triton.kernels.fused_moe import (
3232
fused_moe as triton_fused_moe,
33+
fused_moe_batched as triton_fused_moe_batched,
34+
moe_align_block_size,
3335
)
3436
from executorch.exir import (
3537
EdgeCompileConfig,
@@ -332,6 +334,96 @@ def test_single_expert(self):
332334
rel = diff / (ref.float().abs().max().item() + 1e-10)
333335
self.assertLess(rel, 0.05, f"token {t}: relative diff {rel:.4f}")
334336

337+
def test_batched_correctness(self):
338+
"""Batched kernel matches reference across M values."""
339+
test_cases = [
340+
(42, 8, 64, 32, 4, 2, 32, "8tok_small"),
341+
(7, 16, 64, 32, 8, 4, 32, "16tok_8exp_top4"),
342+
(13, 32, 128, 64, 8, 2, 64, "32tok_gs64"),
343+
(55, 64, 64, 32, 4, 2, 32, "64tok"),
344+
(99, 128, 128, 64, 8, 2, 32, "128tok"),
345+
]
346+
for seed, M, hidden, intermediate, num_experts, top_k, gs, desc in test_cases:
347+
with self.subTest(desc=desc):
348+
torch.manual_seed(seed)
349+
x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
350+
w1_weight = torch.randn(
351+
num_experts,
352+
2 * intermediate,
353+
hidden,
354+
dtype=torch.bfloat16,
355+
device="cuda",
356+
)
357+
w2_weight = torch.randn(
358+
num_experts,
359+
hidden,
360+
intermediate,
361+
dtype=torch.bfloat16,
362+
device="cuda",
363+
)
364+
w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs)
365+
w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs)
366+
w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda()
367+
368+
scores = torch.randn(M, num_experts, device="cuda")
369+
topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
370+
topk_weights = topk_weights.softmax(dim=-1).float()
371+
372+
out = triton_fused_moe_batched(
373+
x,
374+
w1,
375+
w1s,
376+
w2,
377+
w2s,
378+
topk_weights,
379+
topk_ids,
380+
top_k,
381+
num_experts,
382+
gs,
383+
)
384+
385+
w1_dq = _dequantize_int4(w1.cpu(), w1s.cpu(), gs).cuda()
386+
w2_dq = _dequantize_int4(w2.cpu(), w2s.cpu(), gs).cuda()
387+
ref = _reference_moe(x, w1_dq, w2_dq, topk_weights, topk_ids, top_k)
388+
389+
diff = (out.float() - ref.float()).abs().max().item()
390+
rel = diff / (ref.float().abs().max().item() + 1e-10)
391+
self.assertLess(
392+
rel,
393+
0.05,
394+
f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})",
395+
)
396+
397+
def test_batched_matches_fused(self):
398+
"""Batched kernel matches the existing fused_moe kernel at Qwen-scale dims."""
399+
E, top_k, K, inter, gs = 256, 8, 2048, 512, 128
400+
torch.manual_seed(42)
401+
vals = torch.randint(0, 16, (E, 2 * inter, K), dtype=torch.uint8, device="cuda")
402+
w1 = ((vals[:, :, 1::2] << 4) | vals[:, :, 0::2]).to(torch.int8)
403+
w1s = (
404+
torch.randn(E, 2 * inter, K // gs, device="cuda", dtype=torch.bfloat16)
405+
* 0.01
406+
)
407+
vals = torch.randint(0, 16, (E, K, inter), dtype=torch.uint8, device="cuda")
408+
w2 = ((vals[:, :, 1::2] << 4) | vals[:, :, 0::2]).to(torch.int8)
409+
w2s = torch.randn(E, K, inter // gs, device="cuda", dtype=torch.bfloat16) * 0.01
410+
411+
for M in [16, 64, 256]:
412+
with self.subTest(M=M):
413+
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
414+
logits = torch.randn(M, E, device="cuda", dtype=torch.float32)
415+
tw, ti = torch.topk(logits, top_k, dim=-1)
416+
tw = tw.softmax(dim=-1)
417+
ti = ti.to(torch.int64)
418+
419+
out_fused = triton_fused_moe(x, w1, w1s, w2, w2s, tw, ti, top_k, E, gs)
420+
out_batched = triton_fused_moe_batched(
421+
x, w1, w1s, w2, w2s, tw, ti, top_k, E, gs
422+
)
423+
424+
err = (out_fused.float() - out_batched.float()).abs().max().item()
425+
self.assertLess(err, 0.5, f"M={M}: max abs error {err:.4e}")
426+
335427
def test_export_cuda(self):
336428
"""Export succeeds and produces non-empty .pte."""
337429
with tempfile.TemporaryDirectory() as tmpdir:
@@ -395,6 +487,144 @@ def test_e2e_cpp_runner(self):
395487
)
396488

397489

490+
class TestMoeAlignBlockSize(unittest.TestCase):
491+
def setUp(self):
492+
if not torch.cuda.is_available():
493+
self.skipTest("CUDA is not available")
494+
495+
def test_basic_correctness(self):
496+
M, top_k, num_experts, block_size = 4, 2, 4, 4
497+
# Token 0 -> experts 0, 1
498+
# Token 1 -> experts 2, 3
499+
# Token 2 -> experts 0, 2
500+
# Token 3 -> experts 1, 3
501+
topk_ids = torch.tensor(
502+
[[0, 1], [2, 3], [0, 2], [1, 3]], dtype=torch.int64, device="cuda"
503+
)
504+
num_pairs = M * top_k
505+
sentinel = num_pairs
506+
507+
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
508+
topk_ids, block_size, num_experts
509+
)
510+
511+
max_num_tokens_padded = num_pairs + num_experts * block_size
512+
max_num_expert_blocks = max_num_tokens_padded // block_size
513+
self.assertEqual(sorted_token_ids.shape[0], max_num_tokens_padded)
514+
self.assertEqual(expert_ids.shape[0], max_num_expert_blocks)
515+
516+
# Each expert gets exactly 2 tokens, padded to block_size=4
517+
# So num_tokens_post_padded should be 4 * 4 = 16
518+
self.assertEqual(num_tokens_post_padded.item(), 16)
519+
520+
# Verify tokens are grouped by expert within the active region
521+
flat_ids = topk_ids.reshape(-1)
522+
active = sorted_token_ids[: num_tokens_post_padded.item()]
523+
for block_idx in range(num_tokens_post_padded.item() // block_size):
524+
expert = expert_ids[block_idx].item()
525+
block = active[block_idx * block_size : (block_idx + 1) * block_size]
526+
for pair_id in block.tolist():
527+
if pair_id == sentinel:
528+
continue
529+
self.assertEqual(
530+
flat_ids[pair_id].item(),
531+
expert,
532+
f"pair {pair_id} expected expert {expert}, got {flat_ids[pair_id].item()}",
533+
)
534+
535+
def test_all_tokens_same_expert(self):
536+
M, top_k, num_experts, block_size = 4, 2, 4, 4
537+
topk_ids = torch.full((M, top_k), 2, dtype=torch.int64, device="cuda")
538+
num_pairs = M * top_k
539+
sentinel = num_pairs
540+
541+
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
542+
topk_ids, block_size, num_experts
543+
)
544+
545+
# All 8 pairs go to expert 2, padded to block_size=4 -> 8 slots
546+
self.assertEqual(num_tokens_post_padded.item(), 8)
547+
548+
active = sorted_token_ids[: num_tokens_post_padded.item()]
549+
real_ids = active[active != sentinel]
550+
self.assertEqual(real_ids.shape[0], num_pairs)
551+
self.assertTrue(
552+
(sorted(real_ids.tolist()) == list(range(num_pairs))),
553+
"All pair indices should appear exactly once",
554+
)
555+
556+
def test_single_token(self):
557+
num_experts, block_size = 4, 4
558+
topk_ids = torch.tensor([[2]], dtype=torch.int64, device="cuda")
559+
num_pairs = 1
560+
sentinel = num_pairs
561+
562+
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
563+
topk_ids, block_size, num_experts
564+
)
565+
566+
# 1 token to expert 2, padded to block_size=4
567+
self.assertEqual(num_tokens_post_padded.item(), block_size)
568+
569+
active = sorted_token_ids[: num_tokens_post_padded.item()]
570+
real_ids = active[active != sentinel].tolist()
571+
self.assertEqual(real_ids, [0])
572+
sentinel_count = (active == sentinel).sum().item()
573+
self.assertEqual(sentinel_count, block_size - 1)
574+
575+
def test_num_pairs_less_than_block_size(self):
576+
M, top_k, num_experts, block_size = 1, 2, 4, 16
577+
topk_ids = torch.tensor([[0, 3]], dtype=torch.int64, device="cuda")
578+
num_pairs = M * top_k
579+
sentinel = num_pairs
580+
581+
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
582+
topk_ids, block_size, num_experts
583+
)
584+
585+
# 1 token per expert -> each padded to block_size=16, total=32
586+
self.assertEqual(num_tokens_post_padded.item(), 2 * block_size)
587+
588+
active = sorted_token_ids[: num_tokens_post_padded.item()]
589+
real_ids = sorted(active[active != sentinel].tolist())
590+
self.assertEqual(real_ids, [0, 1])
591+
592+
def test_sentinel_value(self):
593+
M, top_k, num_experts, block_size = 2, 2, 4, 4
594+
topk_ids = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="cuda")
595+
num_pairs = M * top_k
596+
sentinel = num_pairs
597+
598+
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
599+
topk_ids, block_size, num_experts
600+
)
601+
602+
# Padding positions within the active region use sentinel = num_pairs
603+
active = sorted_token_ids[: num_tokens_post_padded.item()]
604+
for val in active.tolist():
605+
self.assertTrue(
606+
0 <= val <= sentinel,
607+
f"Value {val} outside valid range [0, {sentinel}]",
608+
)
609+
610+
# Tail beyond active region should also be sentinel
611+
tail = sorted_token_ids[num_tokens_post_padded.item() :]
612+
self.assertTrue((tail == sentinel).all())
613+
614+
def test_determinism(self):
615+
M, top_k, num_experts, block_size = 8, 4, 8, 4
616+
torch.manual_seed(42)
617+
topk_ids = torch.randint(0, num_experts, (M, top_k), device="cuda")
618+
619+
results = [
620+
moe_align_block_size(topk_ids, block_size, num_experts) for _ in range(5)
621+
]
622+
for i in range(1, len(results)):
623+
self.assertTrue(torch.equal(results[0][0], results[i][0]))
624+
self.assertTrue(torch.equal(results[0][1], results[i][1]))
625+
self.assertEqual(results[0][2].item(), results[i][2].item())
626+
627+
398628
def _dequantize_int4(packed, scale, group_size):
399629
"""Dequantize packed INT4 [E, N, K//2] back to [E, N, K] float."""
400630
E, N, K_half = packed.shape

backends/cuda/triton/kernels/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from executorch.backends.cuda.triton.kernels.fused_moe import fused_moe
7+
from executorch.backends.cuda.triton.kernels.fused_moe import (
8+
fused_moe,
9+
fused_moe_batched,
10+
fused_moe_batched_gemm,
11+
moe_align_block_size,
12+
)
813
from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk
914
from executorch.backends.cuda.triton.kernels.topk import topk
1015

1116
__all__ = [
1217
"fused_moe",
18+
"fused_moe_batched",
19+
"fused_moe_batched_gemm",
20+
"moe_align_block_size",
1321
"sdpa",
1422
"sdpa_decode_splitk",
1523
"topk",

0 commit comments

Comments
 (0)