Commit 40c6cc2

Route prefill MoE to batched tensor-core kernel in Qwen3.5 export
Add use_batched_moe flag on FusedMoEExperts, toggled by _set_batched_moe in export.py before each method's torch.export call. Decode (T=1) uses the vec-mat fused_moe kernel; prefill (T>=2) uses fused_moe_batched_gemm. This PR was authored with the assistance of Claude.
1 parent e285a2e commit 40c6cc2
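
The toggle works because torch.export resolves plain Python control flow at trace time: use_batched_moe is an ordinary bool attribute, so each exported method's graph bakes in exactly one kernel, with no runtime branch. A minimal self-contained sketch of the pattern (the Toy module and its stand-in ops are illustrative, not code from this commit):

```python
# Sketch: torch.export evaluates plain Python attribute branches while
# tracing, so flipping the flag between exports captures different graphs.
import torch
from torch.export import export


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.use_batched_moe = False  # ordinary Python bool, not a buffer

    def forward(self, x):
        if self.use_batched_moe:  # resolved at trace time, absent from graph
            return x * 2          # stand-in for fused_moe_batched_gemm
        return x + 1              # stand-in for fused_moe


m = Toy()
decode_ep = export(m, (torch.ones(1),))   # graph contains only x + 1
m.use_batched_moe = True
prefill_ep = export(m, (torch.ones(4),))  # graph contains only x * 2
```

Flipping the attribute between the two export calls is all the commit needs; nothing about the flag survives into the compiled graphs.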

2 files changed: 28 additions & 6 deletions

File tree:

  examples/models/qwen3_5_moe/export.py
  examples/models/qwen3_5_moe/model.py

examples/models/qwen3_5_moe/export.py (14 additions & 6 deletions)
```diff
@@ -503,6 +503,13 @@ def _apply_turboquant(model, config):
 # ---------------------------------------------------------------------------
 
 
+def _set_batched_moe(model, enabled):
+    """Toggle batched tensor-core MoE kernel for all MoE layers."""
+    for layer in model.layers:
+        if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"):
+            layer.mlp.experts.use_batched_moe = enabled
+
+
 def export_and_lower(model, config, args):
     """Export model to .pte via torch.export + backend-specific lowering."""
     backend = getattr(args, "backend", "cuda")
@@ -597,10 +604,9 @@ def _export_cuda(model, config, args):
     """Export model to .pte via torch.export + CUDA backend.
 
     Exports two methods:
-    - "decode": decode path (T=1), uses native PyTorch recurrent FLA
-      so AOTI can fuse with surrounding ops for maximum decode throughput.
-    - "prefill": prefill path (T>=2), uses chunked FLA triton_op with
-      dynamic sequence length.
+    - "decode": decode path (T=1), vec-mat MoE kernel via fused_moe.
+    - "prefill": prefill path (T>=2), batched tensor-core MoE kernel
+      via fused_moe_batched_gemm, with dynamic sequence length.
 
     Both methods share mutable state buffers (KV cache, conv_state,
     recurrent_state) via share_mutable_buffers=True. The model uses
@@ -625,7 +631,8 @@ def _export_cuda(model, config, args):
     # -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
     inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
 
-    # --- Decode method (T=1, static shape) ---
+    # --- Decode method (T=1, static shape, vec-mat MoE kernel) ---
+    _set_batched_moe(model, False)
     print("Exporting decode method...")
     decode_tokens = torch.tensor([[0]], dtype=torch.long)
     decode_pos = torch.tensor([0], dtype=torch.long)
@@ -637,11 +644,12 @@ def _export_cuda(model, config, args):
     )
     print("Decode export successful!")
 
-    # --- Prefill method (T>=2, dynamic shape) ---
+    # --- Prefill method (T>=2, dynamic shape, batched tensor-core MoE kernel) ---
     # Example T must equal max_seq_len-1 so AOTI compiles kernels (especially
     # chunk_gated_delta_rule with CHUNK_SIZE=64) for the full range of sequence
     # lengths. Smaller examples cause AOTI to bake in intermediate buffer sizes
     # that reject longer prompts at runtime.
+    _set_batched_moe(model, True)
     print("Exporting prefill method...")
     example_prefill_len = config.max_seq_len - 1
     prefill_tokens = torch.zeros((1, example_prefill_len), dtype=torch.long)
```
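
The prefill hunk's comment about sizing the example input at max_seq_len - 1 pairs with torch.export's dynamic-shape markings: the sequence dimension is declared dynamic, but the example input still determines how AOTI sizes intermediate buffers. A hedged sketch of that pattern (the toy module, max_seq_len value, and argument name are assumptions for illustration; the commit's actual export call is not shown in this diff):

```python
import torch
from torch.export import Dim, export


class ToyPrefill(torch.nn.Module):
    def forward(self, tokens):
        return tokens.float().cumsum(dim=1)  # stand-in for the real prefill


max_seq_len = 64  # assumed; the real value comes from config.max_seq_len
seq = Dim("seq_len", min=2, max=max_seq_len - 1)

# Example input at the largest supported length, so the compiler sizes
# intermediate buffers for the full range rather than a smaller example's.
example = torch.zeros((1, max_seq_len - 1), dtype=torch.long)
ep = export(ToyPrefill(), (example,), dynamic_shapes={"tokens": {1: seq}})
```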

examples/models/qwen3_5_moe/model.py (14 additions & 0 deletions)
@@ -492,6 +492,7 @@ def __init__(self, config):
492492
self.intermediate_size = config.moe_intermediate_size
493493
self.hidden_size = config.hidden_size
494494
self.group_size = 32
495+
self.use_batched_moe = False
495496

496497
self.w1_weight = nn.Parameter(
497498
torch.empty(
@@ -509,6 +510,19 @@ def __init__(self, config):
509510
)
510511

511512
def forward(self, x, expert_weights, expert_indices, top_k):
513+
if self.use_batched_moe:
514+
return torch.ops.triton.fused_moe_batched_gemm(
515+
x,
516+
self.w1,
517+
self.w1_scale,
518+
self.w2,
519+
self.w2_scale,
520+
expert_weights,
521+
expert_indices,
522+
top_k,
523+
self.num_experts,
524+
self.group_size,
525+
)
512526
return torch.ops.triton.fused_moe(
513527
x,
514528
self.w1,
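
For readers unfamiliar with the kernels being swapped, here is a hedged pure-PyTorch reference of the math a fused MoE layer computes: gather each expert's tokens, run the expert MLP, and scatter back weighted by the router. It ignores the weight quantization the real kernels handle (w1_scale, w2_scale, group_size), and the fused gate+up layout of w1 is an assumption based on common Qwen-style SwiGLU MoE implementations, not this repo's exact format:

```python
import torch
import torch.nn.functional as F


def moe_reference(x, w1, w2, expert_weights, expert_indices):
    # x: [T, H] tokens; w1: [E, 2*I, H] fused gate+up; w2: [E, H, I] down proj
    # expert_weights, expert_indices: [T, top_k] router probabilities / ids
    out = torch.zeros_like(x)
    top_k = expert_indices.shape[1]
    for k in range(top_k):
        for e in expert_indices[:, k].unique():
            rows = (expert_indices[:, k] == e).nonzero(as_tuple=True)[0]
            h = x[rows] @ w1[e].T          # one GEMM per (expert, slot)
            gate, up = h.chunk(2, dim=-1)  # SwiGLU split
            act = F.silu(gate) * up        # [t, I]
            out[rows] += expert_weights[rows, k, None] * (act @ w2[e].T)
    return out


# Tiny smoke test with assumed shapes
T, H, I, E, top_k = 5, 8, 16, 4, 2
x = torch.randn(T, H)
w1 = torch.randn(E, 2 * I, H)
w2 = torch.randn(E, H, I)
expert_weights, expert_indices = torch.randn(T, E).softmax(-1).topk(top_k, dim=-1)
print(moe_reference(x, w1, w2, expert_weights, expert_indices).shape)  # [5, 8]
```

At T=1 each per-expert GEMM degenerates to a matrix-vector product, which is what the vec-mat fused_moe kernel targets; with many prefill tokens per expert, grouping them into batched tensor-core GEMMs pays off, which is why the commit routes prefill (T>=2) to fused_moe_batched_gemm.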
