Skip to content

Commit e06db27

Browse files
committed
Use max example seq_len when exporting Qwen3.5 MoE
The previous example used T=2, which caused AOTI to compile the chunk_gated_delta_rule kernel for a single chunk (NT=1). At runtime, prompts longer than 64 tokens (requiring NT>1 chunks) failed with "Error resizing tensor at input 0". Using max_seq_len-1 as the example ensures AOTI generalizes intermediate buffer sizes for the full sequence length range. Comparison against original export (tq4_sdpa fused kernel) on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096, 5 runs median): Original (tq4_sdpa) Baseline (Triton SDPA) Decode tok/s 68.4 61.7 Prefill tok/s 275.7 378.2 Baseline prefill is 1.37x faster; decode is 0.90x (tq4_sdpa's fused decode kernel is faster than the tiled Triton SDPA at L_q=1). The split-K commit addresses the decode gap.
1 parent 82641e8 commit e06db27

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

examples/models/qwen3_5_moe/export.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,9 +398,13 @@ def export_and_lower(model, config, args):
398398
# -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
399399
inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
400400

401-
# Dynamic shapes
402-
example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
403-
example_input_pos = torch.tensor([0, 1], dtype=torch.long)
401+
# Dynamic shapes — example T must equal max_seq_len-1 so AOTI compiles
402+
# kernels (especially chunk_gated_delta_rule with CHUNK_SIZE=64) for the
403+
# full range of sequence lengths. Smaller examples cause AOTI to bake in
404+
# intermediate buffer sizes that reject longer prompts at runtime.
405+
example_seq_len = config.max_seq_len - 1
406+
example_tokens = torch.zeros((1, example_seq_len), dtype=torch.long)
407+
example_input_pos = torch.arange(example_seq_len, dtype=torch.long)
404408
seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
405409
dynamic_shapes = ({1: seq_dim}, {0: seq_dim})
406410

0 commit comments

Comments
 (0)