@@ -503,6 +503,13 @@ def _apply_turboquant(model, config):
503503# ---------------------------------------------------------------------------
504504
505505
506+ def _set_batched_moe (model , enabled ):
507+ """Toggle batched tensor-core MoE kernel for all MoE layers."""
508+ for layer in model .layers :
509+ if hasattr (layer , "mlp" ) and hasattr (layer .mlp , "experts" ):
510+ layer .mlp .experts .use_batched_moe = enabled
511+
512+
506513def export_and_lower (model , config , args ):
507514 """Export model to .pte via torch.export + backend-specific lowering."""
508515 backend = getattr (args , "backend" , "cuda" )
@@ -597,10 +604,9 @@ def _export_cuda(model, config, args):
597604 """Export model to .pte via torch.export + CUDA backend.
598605
599606 Exports two methods:
600- - "decode": decode path (T=1), uses native PyTorch recurrent FLA
601- so AOTI can fuse with surrounding ops for maximum decode throughput.
602- - "prefill": prefill path (T>=2), uses chunked FLA triton_op with
603- dynamic sequence length.
607+ - "decode": decode path (T=1), vec-mat MoE kernel via fused_moe.
608+ - "prefill": prefill path (T>=2), batched tensor-core MoE kernel
609+ via fused_moe_batched_gemm, with dynamic sequence length.
604610
605611 Both methods share mutable state buffers (KV cache, conv_state,
606612 recurrent_state) via share_mutable_buffers=True. The model uses
@@ -625,7 +631,8 @@ def _export_cuda(model, config, args):
625631 # -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
626632 inductor_config .aot_inductor .compile_wrapper_opt_level = "O0"
627633
628- # --- Decode method (T=1, static shape) ---
634+ # --- Decode method (T=1, static shape, vec-mat MoE kernel) ---
635+ _set_batched_moe (model , False )
629636 print ("Exporting decode method..." )
630637 decode_tokens = torch .tensor ([[0 ]], dtype = torch .long )
631638 decode_pos = torch .tensor ([0 ], dtype = torch .long )
@@ -637,11 +644,12 @@ def _export_cuda(model, config, args):
637644 )
638645 print ("Decode export successful!" )
639646
640- # --- Prefill method (T>=2, dynamic shape) ---
647+ # --- Prefill method (T>=2, dynamic shape, batched tensor-core MoE kernel ) ---
641648 # Example T must equal max_seq_len-1 so AOTI compiles kernels (especially
642649 # chunk_gated_delta_rule with CHUNK_SIZE=64) for the full range of sequence
643650 # lengths. Smaller examples cause AOTI to bake in intermediate buffer sizes
644651 # that reject longer prompts at runtime.
652+ _set_batched_moe (model , True )
645653 print ("Exporting prefill method..." )
646654 example_prefill_len = config .max_seq_len - 1
647655 prefill_tokens = torch .zeros ((1 , example_prefill_len ), dtype = torch .long )
0 commit comments