
Commit a0d199a

Route prefill MoE to batched tensor-core kernel in Qwen3.5 export
Add use_batched_moe flag on FusedMoEExperts, toggled by _set_batched_moe in export.py before each method's torch.export call. Decode (T=1) uses the vec-mat fused_moe kernel; prefill (T>=2) uses fused_moe_batched_gemm.
1 parent: a204847
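
The flag works because torch.export evaluates ordinary Python attributes at trace time: each export call captures only the branch selected by use_batched_moe, so the two exported methods end up with different kernels baked in. A minimal standalone sketch of that specialization behavior (toy module and stand-in ops, not this repo's code):

import torch
import torch.nn as nn
from torch.export import export

class ToyExperts(nn.Module):
    """Stand-in for FusedMoEExperts; only the flag-dispatch pattern matters."""

    def __init__(self):
        super().__init__()
        self.use_batched = False  # plays the role of use_batched_moe

    def forward(self, x):
        if self.use_batched:  # plain Python bool: resolved at trace time
            return x @ x.transpose(-1, -2)  # stand-in for the batched-GEMM path
        return x * 2.0  # stand-in for the vec-mat path

m = ToyExperts()
x = torch.randn(1, 4, 4)

m.use_batched = False
decode_ep = export(m, (x,))   # graph contains only the vec-mat branch

m.use_batched = True
prefill_ep = export(m, (x,))  # graph contains only the batched branch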

2 files changed: 33 additions & 8 deletions


examples/models/qwen3_5_moe/export.py

Lines changed: 19 additions & 8 deletions
@@ -380,14 +380,20 @@ def _apply_turboquant(model, config):
 # ---------------------------------------------------------------------------
 
 
+def _set_batched_moe(model, enabled):
+    """Toggle batched tensor-core MoE kernel for all MoE layers."""
+    for layer in model.layers:
+        if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"):
+            layer.mlp.experts.use_batched_moe = enabled
+
+
 def export_and_lower(model, config, args):
     """Export model to .pte via torch.export + CUDA backend.
 
     Exports two methods:
-    - "decode": decode path (T=1), uses native PyTorch recurrent FLA
-      so AOTI can fuse with surrounding ops for maximum decode throughput.
-    - "prefill": prefill path (T>=2), uses chunked FLA triton_op with
-      dynamic sequence length.
+    - "decode": decode path (T=1), vec-mat MoE kernel via fused_moe.
+    - "prefill": prefill path (T>=2), batched tensor-core MoE kernel
+      via fused_moe_batched_gemm, with dynamic sequence length.
 
     Both methods share mutable state buffers (KV cache, conv_state,
     recurrent_state) via share_mutable_buffers=True. The model uses
@@ -412,7 +418,8 @@ def export_and_lower(model, config, args):
     # -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
     inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
 
-    # --- Decode method (T=1, static shape) ---
+    # --- Decode method (T=1, static shape, vec-mat MoE kernel) ---
+    _set_batched_moe(model, False)
     print("Exporting decode method...")
     decode_tokens = torch.tensor([[0]], dtype=torch.long)
     decode_pos = torch.tensor([0], dtype=torch.long)
@@ -424,10 +431,14 @@ def export_and_lower(model, config, args):
     )
     print("Decode export successful!")
 
-    # --- Prefill method (T>=2, dynamic shape) ---
+    # --- Prefill method (T>=2, dynamic shape, batched tensor-core MoE kernel) ---
+    # Example T must equal max_seq_len-1 so AOTI compiles kernels for the
+    # full range of sequence lengths.
+    _set_batched_moe(model, True)
     print("Exporting prefill method...")
-    prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
-    prefill_pos = torch.tensor([0, 1], dtype=torch.long)
+    example_seq_len = config.max_seq_len - 1
+    prefill_tokens = torch.zeros((1, example_seq_len), dtype=torch.long)
+    prefill_pos = torch.arange(example_seq_len, dtype=torch.long)
     seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
     prefill_dynamic_shapes = (
         {1: seq_dim},  # tokens
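
For context on the example-input change above: Dim only declares the legal range of the dynamic dimension, while (per the diff's own comment) AOTI compiles kernels against the example size, which is why the prefill example is now sized at max_seq_len - 1 instead of T=2. A self-contained sketch of the Dim plumbing implied by the diff; ToyLM and the {0: seq_dim} spec for pos are illustrative assumptions, while example_seq_len, the (tokens, pos) signature, and the seq_dim bounds come from the diff:

import torch
from torch.export import Dim, export

class ToyLM(torch.nn.Module):
    """Illustrative stand-in for the real model's (tokens, pos) signature."""

    def forward(self, tokens, pos):
        return tokens.float() + pos.float()

max_seq_len = 128  # stands in for config.max_seq_len

# Example inputs sized at the maximum dynamic length, per the diff's comment.
example_seq_len = max_seq_len - 1
prefill_tokens = torch.zeros((1, example_seq_len), dtype=torch.long)
prefill_pos = torch.arange(example_seq_len, dtype=torch.long)

seq_dim = Dim("seq_len", min=2, max=max_seq_len - 1)
prefill_ep = export(
    ToyLM(),
    (prefill_tokens, prefill_pos),
    dynamic_shapes=(
        {1: seq_dim},  # tokens: dim 1 (sequence) is dynamic
        {0: seq_dim},  # pos: assumed to share the same symbol (not shown in diff)
    ),
)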

examples/models/qwen3_5_moe/model.py

Lines changed: 14 additions & 0 deletions
@@ -479,6 +479,7 @@ def __init__(self, config):
         self.intermediate_size = config.moe_intermediate_size
         self.hidden_size = config.hidden_size
         self.group_size = 32
+        self.use_batched_moe = False
 
         self.w1_weight = nn.Parameter(
             torch.empty(
@@ -496,6 +497,19 @@ def __init__(self, config):
         )
 
     def forward(self, x, expert_weights, expert_indices, top_k):
+        if self.use_batched_moe:
+            return torch.ops.triton.fused_moe_batched_gemm(
+                x,
+                self.w1,
+                self.w1_scale,
+                self.w2,
+                self.w2_scale,
+                expert_weights,
+                expert_indices,
+                top_k,
+                self.num_experts,
+                self.group_size,
+            )
         return torch.ops.triton.fused_moe(
             x,
             self.w1,
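
The triton ops themselves are not part of this diff. For orientation, here is an eager-mode reference for what a top-k routed MoE forward of this shape computes, a sketch only: it assumes unquantized weights and a SwiGLU expert MLP with gate/up packed into w1, whereas the real kernels take quantized weights plus per-group scales (w1_scale, w2_scale, group_size=32):

import torch
import torch.nn.functional as F

def moe_reference(x, w1, w2, expert_weights, expert_indices, top_k):
    """Eager reference: x (T, H), w1 (E, 2*I, H), w2 (E, H, I),
    expert_weights/expert_indices (T, top_k). NOT the repo's kernel."""
    T, _ = x.shape
    out = torch.zeros_like(x)
    for t in range(T):
        for k in range(top_k):
            e = int(expert_indices[t, k])
            gate, up = (w1[e] @ x[t]).chunk(2)  # assumed gate/up packing
            h = F.silu(gate) * up               # SwiGLU activation
            out[t] += expert_weights[t, k] * (w2[e] @ h)
    return out

# Tiny usage example with random routing:
T, H, I, E, top_k = 4, 8, 16, 4, 2
x = torch.randn(T, H)
w1, w2 = torch.randn(E, 2 * I, H), torch.randn(E, H, I)
probs = torch.softmax(torch.randn(T, E), dim=-1)
expert_weights, expert_indices = torch.topk(probs, top_k, dim=-1)
y = moe_reference(x, w1, w2, expert_weights, expert_indices, top_k)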
