Commit ecca812

Use max example seq_len when exporting Qwen3.5 MoE
The previous example used T=2, which caused AOTI to compile the chunk_gated_delta_rule kernel for a single chunk (NT=1). At runtime, prompts longer than 64 tokens (requiring NT>1 chunks) failed with "Error resizing tensor at input 0". Using max_seq_len-1 as the example input ensures AOTI generalizes intermediate buffer sizes across the full sequence-length range.

Comparison against the original export (tq4_sdpa fused kernel) on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096, median of 5 runs):

                 Original (tq4_sdpa)   Baseline (Triton SDPA)
  Decode tok/s          68.4                  61.7
  Prefill tok/s        275.7                 378.2

Baseline prefill is 1.37x faster; decode is 0.90x (tq4_sdpa's fused decode kernel is faster than the tiled Triton SDPA at L_q=1). The split-K commit addresses the decode gap.

This PR was authored with the assistance of Claude.
1 parent 40e361b commit ecca812
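The chunk-count arithmetic behind the failure can be sketched as follows (a minimal illustration; `num_chunks` is a hypothetical helper, with CHUNK_SIZE=64 taken from the commit message):

```python
import math

CHUNK_SIZE = 64  # chunk width used by chunk_gated_delta_rule, per the commit message

def num_chunks(seq_len: int, chunk_size: int = CHUNK_SIZE) -> int:
    """Number of chunks (NT) processed for a given sequence length."""
    return math.ceil(seq_len / chunk_size)

# The old example input (T=2) fits in a single chunk, so AOTI specialized for NT=1:
print(num_chunks(2))     # 1
# Any prompt longer than 64 tokens needs NT > 1 and hit the resize error:
print(num_chunks(65))    # 2
# Exporting with T = max_seq_len - 1 = 4095 exercises the largest chunk count:
print(num_chunks(4095))  # 64
```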

1 file changed: examples/models/qwen3_5_moe/export.py (7 additions, 2 deletions)
@@ -626,9 +626,14 @@ def _export_cuda(model, config, args):
     print("Decode export successful!")

     # --- Prefill method (T>=2, dynamic shape) ---
+    # Example T must equal max_seq_len-1 so AOTI compiles kernels (especially
+    # chunk_gated_delta_rule with CHUNK_SIZE=64) for the full range of sequence
+    # lengths. Smaller examples cause AOTI to bake in intermediate buffer sizes
+    # that reject longer prompts at runtime.
     print("Exporting prefill method...")
-    prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
-    prefill_pos = torch.tensor([0, 1], dtype=torch.long)
+    example_prefill_len = config.max_seq_len - 1
+    prefill_tokens = torch.zeros((1, example_prefill_len), dtype=torch.long)
+    prefill_pos = torch.arange(example_prefill_len, dtype=torch.long)
     seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
     prefill_dynamic_shapes = (
         {1: seq_dim},  # tokens
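The failure mode the diff fixes, an intermediate buffer sized from the example input rather than from the declared dynamic bound, can be illustrated with a small framework-free sketch (all names here are illustrative stand-ins, not AOTI internals):

```python
import math

CHUNK_SIZE = 64  # chunk width from the commit message

class CompiledPrefill:
    """Toy stand-in for an AOTI-compiled artifact: the intermediate buffer
    capacity is fixed at compile time from the example input's chunk count."""

    def __init__(self, example_seq_len: int):
        # Capacity baked in from the example input (NT chunks).
        self.max_chunks = math.ceil(example_seq_len / CHUNK_SIZE)

    def run(self, seq_len: int) -> int:
        nt = math.ceil(seq_len / CHUNK_SIZE)
        if nt > self.max_chunks:
            # Mirrors the runtime failure reported in the commit message.
            raise RuntimeError("Error resizing tensor at input 0")
        return nt

# Exported with the old T=2 example: prompts over 64 tokens fail.
old = CompiledPrefill(example_seq_len=2)
print(old.run(64))    # 1 chunk, fits
# old.run(65)         # would raise: needs 2 chunks, only 1 was compiled for

# Exported with T = max_seq_len - 1 (4095): the full range works.
new = CompiledPrefill(example_seq_len=4095)
print(new.run(4095))  # 64
```

This is why the example input must sit at the dynamic range's upper bound: `Dim("seq_len", ...)` declares the valid range to the exporter, but buffer shapes derived from data-dependent chunk counts are still specialized from the concrete example.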
