@@ -380,14 +380,20 @@ def _apply_turboquant(model, config):
380380# ---------------------------------------------------------------------------
381381
382382
383+ def _set_batched_moe (model , enabled ):
384+ """Toggle batched tensor-core MoE kernel for all MoE layers."""
385+ for layer in model .layers :
386+ if hasattr (layer , "mlp" ) and hasattr (layer .mlp , "experts" ):
387+ layer .mlp .experts .use_batched_moe = enabled
388+
389+
383390def export_and_lower (model , config , args ):
384391 """Export model to .pte via torch.export + CUDA backend.
385392
386393 Exports two methods:
387- - "decode": decode path (T=1), uses native PyTorch recurrent FLA
388- so AOTI can fuse with surrounding ops for maximum decode throughput.
389- - "prefill": prefill path (T>=2), uses chunked FLA triton_op with
390- dynamic sequence length.
394+ - "decode": decode path (T=1), vec-mat MoE kernel via fused_moe.
395+ - "prefill": prefill path (T>=2), batched tensor-core MoE kernel
396+ via fused_moe_batched_gemm, with dynamic sequence length.
391397
392398 Both methods share mutable state buffers (KV cache, conv_state,
393399 recurrent_state) via share_mutable_buffers=True. The model uses
@@ -412,7 +418,8 @@ def export_and_lower(model, config, args):
412418 # -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
413419 inductor_config .aot_inductor .compile_wrapper_opt_level = "O0"
414420
415- # --- Decode method (T=1, static shape) ---
421+ # --- Decode method (T=1, static shape, vec-mat MoE kernel) ---
422+ _set_batched_moe (model , False )
416423 print ("Exporting decode method..." )
417424 decode_tokens = torch .tensor ([[0 ]], dtype = torch .long )
418425 decode_pos = torch .tensor ([0 ], dtype = torch .long )
@@ -424,10 +431,14 @@ def export_and_lower(model, config, args):
424431 )
425432 print ("Decode export successful!" )
426433
427- # --- Prefill method (T>=2, dynamic shape) ---
434+ # --- Prefill method (T>=2, dynamic shape, batched tensor-core MoE kernel) ---
435+ # Example T must equal max_seq_len-1 so AOTI compiles kernels for the
436+ # full range of sequence lengths.
437+ _set_batched_moe (model , True )
428438 print ("Exporting prefill method..." )
429- prefill_tokens = torch .tensor ([[0 , 1 ]], dtype = torch .long )
430- prefill_pos = torch .tensor ([0 , 1 ], dtype = torch .long )
439+ example_seq_len = config .max_seq_len - 1
440+ prefill_tokens = torch .zeros ((1 , example_seq_len ), dtype = torch .long )
441+ prefill_pos = torch .arange (example_seq_len , dtype = torch .long )
431442 seq_dim = Dim ("seq_len" , min = 2 , max = config .max_seq_len - 1 )
432443 prefill_dynamic_shapes = (
433444 {1 : seq_dim }, # tokens
0 commit comments