Skip to content

Commit 66d2798

Browse files
committed
Try to use the KA-optimized MoE kernel
1 parent df2a4bf commit 66d2798

1 file changed

Lines changed: 18 additions & 10 deletions

File tree

backends/cuda/triton/kernels/fused_moe.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5),
5252
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3),
5353
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=5),
54+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=4),
55+
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3),
5456
]
5557

5658
# Autotune configs for GEMM2 (_fused_moe_silu_kernel).
@@ -63,6 +65,8 @@
6365
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=4),
6466
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=2, num_stages=3),
6567
triton.Config({"BLOCK_SIZE_N": 8, "BLOCK_SIZE_K": 256}, num_warps=4, num_stages=3),
68+
triton.Config({"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128}, num_warps=2, num_stages=3),
69+
triton.Config({"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=4),
6670
]
6771

6872

@@ -171,9 +175,9 @@ def _fused_moe_kernel(
171175
scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
172176
).to(tl.float32)
173177

174-
# Dequantize and accumulate: vector-matrix multiply
175-
b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type)
176-
acc += tl.sum(a[:, None].to(compute_type) * b_dequant, axis=0)
178+
# Dequantize and accumulate in float32: vector-matrix multiply
179+
b_dequant = (b.to(tl.float32) - 8.0) * b_scale
180+
acc += tl.sum(a[:, None].to(tl.float32) * b_dequant, axis=0)
177181

178182
# Advance K pointers
179183
a_ptrs += BLOCK_SIZE_K * stride_ak
@@ -259,10 +263,10 @@ def _fused_moe_silu_kernel(
259263
k_remaining = K - k_step * BLOCK_SIZE_K
260264
k_mask = offs_k < k_remaining
261265

262-
# Load gate and up, apply SiLU(gate) * up
266+
# Load gate and up in float32, apply SiLU(gate) * up
263267
gate = tl.load(a_gate_ptrs, mask=k_mask, other=0.0).to(tl.float32)
264-
up = tl.load(a_up_ptrs, mask=k_mask, other=0.0)
265-
a = (gate * tl.sigmoid(gate) * up).to(compute_type)
268+
up = tl.load(a_up_ptrs, mask=k_mask, other=0.0).to(tl.float32)
269+
a = gate * tl.sigmoid(gate) * up
266270

267271
# Load and dequantize INT4 weights
268272
b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0)
@@ -290,8 +294,8 @@ def _fused_moe_silu_kernel(
290294
scale_ptrs, mask=k_mask[:, None] & n_mask[None, :], other=0.0
291295
).to(tl.float32)
292296

293-
b_dequant = ((b.to(tl.float32) - 8.0) * b_scale).to(compute_type)
294-
acc += tl.sum(a[:, None].to(compute_type) * b_dequant, axis=0)
297+
b_dequant = (b.to(tl.float32) - 8.0) * b_scale
298+
acc += tl.sum(a[:, None] * b_dequant, axis=0)
295299

296300
a_gate_ptrs += BLOCK_SIZE_K * stride_ak
297301
a_up_ptrs += BLOCK_SIZE_K * stride_ak
@@ -571,6 +575,8 @@ def moe_align_block_size(
571575
triton.Config(
572576
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2
573577
),
578+
triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=8, num_stages=4),
579+
triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=8, num_stages=4),
574580
]
575581

576582
# Autotune configs for batched GEMM2 (down projection + SiLU).
@@ -581,6 +587,8 @@ def moe_align_block_size(
581587
triton.Config(
582588
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2
583589
),
590+
triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=8, num_stages=3),
591+
triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=8, num_stages=4),
584592
]
585593

586594

@@ -778,9 +786,9 @@ def _fused_moe_silu_batched_kernel(
778786
k_remaining = K - k_step * BLOCK_SIZE_K
779787
k_mask = offs_k < k_remaining
780788

781-
# Load gate and up tiles [BLOCK_M, BLOCK_K], apply SiLU
789+
# Load gate and up in float32 [BLOCK_M, BLOCK_K], apply SiLU
782790
gate = tl.load(a_gate_ptrs, mask=k_mask[None, :], other=0.0).to(tl.float32)
783-
up = tl.load(a_up_ptrs, mask=k_mask[None, :], other=0.0)
791+
up = tl.load(a_up_ptrs, mask=k_mask[None, :], other=0.0).to(tl.float32)
784792
a = (gate * tl.sigmoid(gate) * up).to(compute_type)
785793

786794
# Load and dequantize INT4 weights [BLOCK_K, BLOCK_N]

0 commit comments

Comments
 (0)