Skip to content

Commit 6c3bab1

Browse files
committed
further optimized for group_size == 32 and bz == 1
1 parent 21c4924 commit 6c3bab1

2 files changed

Lines changed: 84 additions & 50 deletions

File tree

backends/cuda/benchmarks/benchmark_moe.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
1111
Measures latency across prompt lengths matching the Qwen3.5-35B-A3B model
1212
(hidden_size=2048, num_experts=256, top_k=8, intermediate_size=512,
13-
INT4 weight-only quantization with group_size=128).
13+
INT4 weight-only quantization with group_size=32).
1414
1515
Usage:
1616
python benchmark_moe.py
@@ -21,7 +21,6 @@
2121
from functools import partial
2222

2323
import executorch.backends.cuda.triton.kernels # noqa: F401 — registers triton ops
24-
2524
import torch
2625
from triton.testing import do_bench
2726

@@ -33,7 +32,7 @@
3332
"top_k": 8,
3433
"hidden_size": 2048,
3534
"intermediate_size": 512,
36-
"group_size": 128,
35+
"group_size": 32,
3736
}
3837

3938
PROMPT_LENGTHS = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4095]

backends/cuda/triton/kernels/fused_moe.py

Lines changed: 82 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -42,31 +42,39 @@
4242

4343

4444
# Autotune configs for GEMM1 (_fused_moe_kernel).
45-
# Top performers from CI benchmark on A100-SXM4-80GB, Qwen3.5 MoE dimensions
46-
# (M=1, N=1024, K=2048, 8 experts, group_size=128).
45+
# Qwen3.5 MoE dimensions (M=1, N=1024, K=2048, 8 experts, group_size=32).
46+
# BLOCK_K ≤ 32 ensures the efficient one-scale-per-tile path.
4747
_GEMM1_CONFIGS = [
48+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=2),
49+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
50+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=5),
51+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3),
52+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
53+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=3),
54+
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3),
55+
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
4856
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2),
49-
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=2),
50-
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4),
51-
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5),
52-
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3),
53-
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5),
54-
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=4),
55-
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3),
57+
triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=3),
58+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 16}, num_warps=2, num_stages=4),
59+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16}, num_warps=2, num_stages=5),
5660
]
5761

5862
# Autotune configs for GEMM2 (_fused_moe_silu_kernel).
59-
# Top performers from CI benchmark on A100-SXM4-80GB, Qwen3.5 MoE dimensions
60-
# (M=1, N=2048, K=512, 8 experts, group_size=128).
63+
# Qwen3.5 MoE dimensions (M=1, N=2048, K=512, 8 experts, group_size=32).
64+
# BLOCK_K ≤ 32 ensures the efficient one-scale-per-tile path.
6165
_GEMM2_CONFIGS = [
66+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=2),
67+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
68+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=5),
69+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3),
70+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
71+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=3),
72+
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=3),
73+
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=2, num_stages=4),
6274
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=2),
63-
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 128}, num_warps=2, num_stages=4),
64-
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=4),
65-
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=4),
66-
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3),
67-
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=3),
68-
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128}, num_warps=2, num_stages=3),
69-
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=4),
75+
triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=3),
76+
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 16}, num_warps=2, num_stages=4),
77+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16}, num_warps=2, num_stages=5),
7078
]
7179

7280

@@ -76,7 +84,7 @@ def _fused_moe_kernel(
7684
# Pointers
7785
A, # [M, K] bf16 activations
7886
B, # [E, N, K//2] int8 packed INT4 weights
79-
C, # [M * top_k, N] bf16 output
87+
C, # [M, N] fp32 output (atomic accumulation across experts)
8088
B_scale, # [E, N, K//group_size] bf16 scales
8189
topk_ids, # [M * top_k] int64 expert indices
8290
topk_weights, # [M * top_k] float32 router weights
@@ -144,11 +152,16 @@ def _fused_moe_kernel(
144152
k_remaining = K - k_step * BLOCK_SIZE_K
145153
k_mask = offs_k < k_remaining
146154

147-
# Load A tile [BLOCK_SIZE_K]
148-
a = tl.load(a_ptrs, mask=k_mask, other=0.0)
155+
# Load A tile [BLOCK_SIZE_K] — reused across N-blocks, keep in L2
156+
a = tl.load(a_ptrs, mask=k_mask, other=0.0, eviction_policy="evict_last")
149157

150-
# Load B tile [BLOCK_SIZE_K, BLOCK_SIZE_N] and unpack INT4
151-
b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0)
158+
# Load B tile [BLOCK_SIZE_K, BLOCK_SIZE_N] and unpack INT4 — streaming
159+
b = tl.load(
160+
b_ptrs,
161+
mask=k_mask[:, None] & n_mask[None, :],
162+
other=0,
163+
eviction_policy="evict_first",
164+
)
152165
b = (b >> b_shifter) & 0xF
153166

154167
# Load per-group scales and dequantize
@@ -161,9 +174,12 @@ def _fused_moe_kernel(
161174
+ offs_n[None, :] * stride_bsn
162175
+ group_idx * stride_bsk
163176
)
164-
b_scale = tl.load(scale_ptrs, mask=n_mask[None, :], other=0.0).to(
165-
tl.float32
166-
)
177+
b_scale = tl.load(
178+
scale_ptrs,
179+
mask=n_mask[None, :],
180+
other=0.0,
181+
eviction_policy="evict_first",
182+
).to(tl.float32)
167183
else:
168184
scale_ptrs = (
169185
B_scale
@@ -172,7 +188,10 @@ def _fused_moe_kernel(
172188
+ ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
173189
)
174190
b_scale = tl.load(
175-
scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
191+
scale_ptrs,
192+
mask=k_mask[:, None] & n_mask[None, :],
193+
other=0.0,
194+
eviction_policy="evict_first",
176195
).to(tl.float32)
177196

178197
# Dequantize and accumulate in float32: vector-matrix multiply
@@ -222,6 +241,7 @@ def _fused_moe_silu_kernel(
222241
group_size: tl.constexpr,
223242
BLOCK_SIZE_N: tl.constexpr,
224243
BLOCK_SIZE_K: tl.constexpr,
244+
top_k: tl.constexpr,
225245
compute_type: tl.constexpr,
226246
):
227247
"""GEMM2 with fused SiLU activation.
@@ -263,13 +283,22 @@ def _fused_moe_silu_kernel(
263283
k_remaining = K - k_step * BLOCK_SIZE_K
264284
k_mask = offs_k < k_remaining
265285

266-
# Load gate and up in float32, apply SiLU(gate) * up
267-
gate = tl.load(a_gate_ptrs, mask=k_mask, other=0.0).to(tl.float32)
268-
up = tl.load(a_up_ptrs, mask=k_mask, other=0.0).to(tl.float32)
286+
# Load gate and up in float32, apply SiLU(gate) * up — reused across N-blocks
287+
gate = tl.load(
288+
a_gate_ptrs, mask=k_mask, other=0.0, eviction_policy="evict_last"
289+
).to(tl.float32)
290+
up = tl.load(
291+
a_up_ptrs, mask=k_mask, other=0.0, eviction_policy="evict_last"
292+
).to(tl.float32)
269293
a = gate * tl.sigmoid(gate) * up
270294

271-
# Load and dequantize INT4 weights
272-
b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0)
295+
# Load and dequantize INT4 weights — streaming
296+
b = tl.load(
297+
b_ptrs,
298+
mask=k_mask[:, None] & n_mask[None, :],
299+
other=0,
300+
eviction_policy="evict_first",
301+
)
273302
b = (b >> b_shifter) & 0xF
274303

275304
if BLOCK_SIZE_K <= group_size:
@@ -280,9 +309,12 @@ def _fused_moe_silu_kernel(
280309
+ offs_n[None, :] * stride_bsn
281310
+ group_idx * stride_bsk
282311
)
283-
b_scale = tl.load(scale_ptrs, mask=n_mask[None, :], other=0.0).to(
284-
tl.float32
285-
)
312+
b_scale = tl.load(
313+
scale_ptrs,
314+
mask=n_mask[None, :],
315+
other=0.0,
316+
eviction_policy="evict_first",
317+
).to(tl.float32)
286318
else:
287319
scale_ptrs = (
288320
B_scale
@@ -291,7 +323,10 @@ def _fused_moe_silu_kernel(
291323
+ ((offs_k[:, None] + BLOCK_SIZE_K * k_step) // group_size) * stride_bsk
292324
)
293325
b_scale = tl.load(
294-
scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
326+
scale_ptrs,
327+
mask=k_mask[:, None] & n_mask[None, :],
328+
other=0.0,
329+
eviction_policy="evict_first",
295330
).to(tl.float32)
296331

297332
b_dequant = (b.to(tl.float32) - 8.0) * b_scale
@@ -301,12 +336,13 @@ def _fused_moe_silu_kernel(
301336
a_up_ptrs += BLOCK_SIZE_K * stride_ak
302337
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
303338

304-
# Multiply by router weight
339+
# Multiply by router weight and atomically accumulate into token row
305340
weight = tl.load(topk_weights + pair_idx)
306341
acc = acc * weight
307342

308-
c_ptrs = C + pair_idx * stride_cm + offs_n * stride_cn
309-
tl.store(c_ptrs, acc.to(compute_type), mask=n_mask)
343+
token_idx = pair_idx // top_k
344+
c_ptrs = C + token_idx * stride_cm + offs_n * stride_cn
345+
tl.atomic_add(c_ptrs, acc, mask=n_mask)
310346

311347

312348
# ---------------------------------------------------------------------------
@@ -394,17 +430,16 @@ def grid1(meta):
394430
)
395431

396432
# ---- GEMM2 with fused SiLU: reads gate+up from cache1, no intermediate buffer ----
397-
cache3 = torch.empty(
398-
num_pairs, N2, dtype=hidden_states.dtype, device=hidden_states.device
399-
)
433+
# Zero-init FP32 buffer — atomic_add in the kernel accumulates across top_k experts
434+
output = torch.zeros(M, N2, dtype=torch.float32, device=hidden_states.device)
400435

401436
def grid2(meta):
402437
return (num_pairs * triton.cdiv(N2, meta["BLOCK_SIZE_N"]),)
403438

404439
wrap_triton(_fused_moe_silu_kernel)[grid2](
405440
cache1,
406441
w2,
407-
cache3,
442+
output,
408443
w2_scale,
409444
topk_ids_flat,
410445
topk_weights_flat,
@@ -416,17 +451,17 @@ def grid2(meta):
416451
stride_be=w2.stride(0),
417452
stride_bk=w2.stride(2),
418453
stride_bn=w2.stride(1),
419-
stride_cm=cache3.stride(0),
420-
stride_cn=cache3.stride(1),
454+
stride_cm=output.stride(0),
455+
stride_cn=output.stride(1),
421456
stride_bse=w2_scale.stride(0),
422457
stride_bsk=w2_scale.stride(2),
423458
stride_bsn=w2_scale.stride(1),
424459
group_size=group_size,
460+
top_k=top_k,
425461
compute_type=tl.bfloat16,
426462
)
427463

428-
# ---- Sum across top-k experts ----
429-
return cache3.view(M, top_k, N2).sum(dim=1)
464+
return output.to(hidden_states.dtype)
430465

431466

432467
@fused_moe.register_fake

0 commit comments

Comments
 (0)