Skip to content

Commit 217ad45

Browse files
authored
AutoTune MoE kernel block sizes to accelerate inference (#18551)
## Summary This PR introduces Triton autotuning for MoE kernels, improving Qwen3.5 MoE model inference from **66.8 token/s → 77.7 token/s**. ## Motivation Profiling the Qwen3.5 MoE model (prior to GQA/MQA support in Triton SDPA) shows MoE dominates GPU time: | Category | Total (ms) | % GPU | |---|---|---| | **MoE** | **1,420** | **54.7%** | | Triton fused ops | 433 | 16.7% | | SDPA | 288 | 11.1% | | int4mm | 240 | 9.2% | | chunk_gated_delta_rule | 151 | 5.8% | | Router | 65 | 2.5% | The `fused_moe` kernel is the single largest bottleneck, making it the highest-leverage optimization target. ## Approach Due to hardware constraints, exhaustive autotuning at `aoti-compile` time is impractical. Instead, we: 1. **Benchmarked** all hyperparameter combinations for MoE kernels on an A100 server ([full results](https://gist.github.com/Gasoonjia/baae2475684d1246c82865ff5cbd949d)) 2. **Selected** the top-5 configurations plus the original `(N=32, K=32)` baseline 3. **Registered** them as `@triton.autotune` configs for the MoE kernels ## Results — MoE Kernel | Kernel | Best Config | Baseline | Best | Improvement | |---|---|---|---|---| | GEMM1 | `(8, 256, w2, s2)` | 60.4 µs | 32.8 µs | **45.8% faster** | | GEMM2 | `(8, 128, w2, s4)` | 29.2 µs | 26.1 µs | **10.6% faster** | **MoE kernel overall: 89.6 µs → 58.9 µs (34.3% improvement)** ## Results — End-to-End Inference | | Token/s | |---|---| | Baseline | 66.8 | | With this PR | **77.7** |
1 parent 186eb4b commit 217ad45

2 files changed

Lines changed: 37 additions & 9 deletions

File tree

.ci/scripts/export_model_artifact.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
424424
test -f "${OUTPUT_DIR}/model.pte"
425425
test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd"
426426
ls -al "${OUTPUT_DIR}"
427+
427428
exit 0
428429
fi
429430

backends/cuda/triton/kernels/fused_moe.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,32 @@
3838
from torch.library import triton_op, wrap_triton
3939

4040

41+
# Autotune configs for GEMM1 (_fused_moe_kernel).
42+
# Top performers from CI benchmark on A100-SXM4-80GB, Qwen3.5 MoE dimensions
43+
# (M=1, N=1024, K=2048, 8 experts, group_size=128).
44+
_GEMM1_CONFIGS = [
45+
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2),
46+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=2),
47+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4),
48+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5),
49+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3),
50+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5),
51+
]
52+
53+
# Autotune configs for GEMM2 (_fused_moe_silu_kernel).
54+
# Top performers from CI benchmark on A100-SXM4-80GB, Qwen3.5 MoE dimensions
55+
# (M=1, N=2048, K=512, 8 experts, group_size=128).
56+
_GEMM2_CONFIGS = [
57+
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2),
58+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 128}, num_warps=2, num_stages=4),
59+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4),
60+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=4),
61+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3),
62+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=3),
63+
]
64+
65+
66+
@triton.autotune(configs=_GEMM1_CONFIGS, key=["N", "K"])
4167
@triton.jit
4268
def _fused_moe_kernel(
4369
# Pointers
@@ -147,6 +173,7 @@ def _fused_moe_kernel(
147173
tl.store(c_ptrs, acc.to(compute_type), mask=n_mask)
148174

149175

176+
@triton.autotune(configs=_GEMM2_CONFIGS, key=["N", "K"])
150177
@triton.jit
151178
def _fused_moe_silu_kernel(
152179
# Pointers
@@ -294,18 +321,19 @@ def fused_moe(
294321
N2 = w2.shape[1] # hidden_size
295322
num_pairs = M * top_k
296323

297-
BLOCK_SIZE_N = 32
298-
BLOCK_SIZE_K = 32
299-
300324
# Flatten topk tensors
301325
topk_ids_flat = topk_ids.reshape(-1)
302326
topk_weights_flat = topk_weights.reshape(-1)
303327

304328
# ---- GEMM1: gate + up projection ----
329+
# Grid is a callable (evaluated per-config) because BLOCK_SIZE_N is selected by autotune
305330
cache1 = torch.empty(
306331
num_pairs, N1, dtype=hidden_states.dtype, device=hidden_states.device
307332
)
308-
grid1 = (num_pairs * triton.cdiv(N1, BLOCK_SIZE_N),)
333+
334+
def grid1(meta):
335+
return (num_pairs * triton.cdiv(N1, meta["BLOCK_SIZE_N"]),)
336+
309337
wrap_triton(_fused_moe_kernel)[grid1](
310338
hidden_states,
311339
w1,
@@ -327,8 +355,6 @@ def fused_moe(
327355
stride_bsk=w1_scale.stride(2),
328356
stride_bsn=w1_scale.stride(1),
329357
group_size=group_size,
330-
BLOCK_SIZE_N=BLOCK_SIZE_N,
331-
BLOCK_SIZE_K=BLOCK_SIZE_K,
332358
MUL_ROUTED_WEIGHT=False,
333359
top_k=top_k,
334360
compute_type=tl.bfloat16,
@@ -338,7 +364,10 @@ def fused_moe(
338364
cache3 = torch.empty(
339365
num_pairs, N2, dtype=hidden_states.dtype, device=hidden_states.device
340366
)
341-
grid2 = (num_pairs * triton.cdiv(N2, BLOCK_SIZE_N),)
367+
368+
def grid2(meta):
369+
return (num_pairs * triton.cdiv(N2, meta["BLOCK_SIZE_N"]),)
370+
342371
wrap_triton(_fused_moe_silu_kernel)[grid2](
343372
cache1,
344373
w2,
@@ -360,8 +389,6 @@ def fused_moe(
360389
stride_bsk=w2_scale.stride(2),
361390
stride_bsn=w2_scale.stride(1),
362391
group_size=group_size,
363-
BLOCK_SIZE_N=BLOCK_SIZE_N,
364-
BLOCK_SIZE_K=BLOCK_SIZE_K,
365392
compute_type=tl.bfloat16,
366393
)
367394

0 commit comments

Comments
 (0)