Commit a05c2d1

Use max example seq_len when exporting Qwen3.5 MoE
The previous example used T=2, which caused AOTI to compile the chunk_gated_delta_rule kernel for a single chunk (NT=1). At runtime, prompts longer than 64 tokens (requiring NT>1 chunks) failed with "Error resizing tensor at input 0". Using max_seq_len-1 as the example ensures AOTI generalizes intermediate buffer sizes for the full sequence length range.

Comparison against original export (tq4_sdpa fused kernel) on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096, 5 runs median):

                 Original (tq4_sdpa)   Baseline (Triton SDPA)
Decode tok/s     68.4                  61.7
Prefill tok/s    275.7                 378.2

Baseline prefill is 1.37x faster; decode is 0.90x (tq4_sdpa's fused decode kernel is faster than the tiled Triton SDPA at L_q=1). The split-K commit addresses the decode gap.

This PR was authored with the assistance of Claude.
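As a rough illustration of the chunk arithmetic in the message above, the sketch below assumes NT is the ceiling of T divided by the chunk size of 64 quoted in the commit. The helper `num_chunks` is hypothetical and is not part of the repository; it only shows why a T=2 example exercises a single chunk while a max_seq_len-1 example exercises the largest chunk count.

```python
CHUNK_SIZE = 64  # chunk size quoted in the commit (assumed to apply uniformly)


def num_chunks(seq_len: int, chunk_size: int = CHUNK_SIZE) -> int:
    """Hypothetical helper: chunks (NT) needed to cover seq_len tokens."""
    if seq_len < 1:
        raise ValueError("seq_len must be positive")
    # Ceiling division without floating point.
    return -(-seq_len // chunk_size)


max_seq_len = 4096                   # value used in the commit's benchmark
old_example_len = 2                  # previous example T
new_example_len = max_seq_len - 1    # example T after this commit

print(num_chunks(old_example_len))   # 1: only the single-chunk case is seen
print(num_chunks(65))                # 2: first prompt length that needs NT > 1
print(num_chunks(new_example_len))   # 64: largest chunk count in the range
```

Under this assumption, any prompt of 65 tokens or more falls outside what the T=2 example exercised, which matches the failure threshold described in the message.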
1 parent 8c9868a commit a05c2d1

1 file changed

Lines changed: 7 additions & 2 deletions

examples/models/qwen3_5_moe/export.py

@@ -425,9 +425,14 @@ def export_and_lower(model, config, args):
     print("Decode export successful!")
 
     # --- Prefill method (T>=2, dynamic shape) ---
+    # Example T must equal max_seq_len-1 so AOTI compiles kernels (especially
+    # chunk_gated_delta_rule with CHUNK_SIZE=64) for the full range of sequence
+    # lengths. Smaller examples cause AOTI to bake in intermediate buffer sizes
+    # that reject longer prompts at runtime.
     print("Exporting prefill method...")
-    prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
-    prefill_pos = torch.tensor([0, 1], dtype=torch.long)
+    example_prefill_len = config.max_seq_len - 1
+    prefill_tokens = torch.zeros((1, example_prefill_len), dtype=torch.long)
+    prefill_pos = torch.arange(example_prefill_len, dtype=torch.long)
     seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
     prefill_dynamic_shapes = (
         {1: seq_dim},  # tokens
