Commit a6aba5e

Update on "Add W4A8 INT8 activation kernels for batched MoE prefill"
INT8 tensor-core variants of the batched MoE GEMM kernels: bf16 activations are dynamically quantized to INT8 per row, per K-tile, and INT4 weights are dequantized directly to INT8 (skipping the bf16 conversion). Accumulation uses tl.dot(int8, int8) → int32 with a per-tile float32 rescale. 1.7× MoE speedup on A100 at M=1024, with 0.9998 cosine similarity vs. the bf16 baseline.

Co-authored-by: Claude <noreply@anthropic.com>

[ghstack-poisoned]
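For readers skimming the message, the quantization scheme it describes reduces to the following math. This is a NumPy reference sketch only, with made-up shapes and names (a, w_int4, w_scale); it is not the Triton kernel added in this commit.

import numpy as np

# Reference-only sketch of the W4A8 math in the commit message: per-row INT8
# activation quantization, INT4 -> INT8 weight dequantization, int32
# accumulation, float32 rescale. Illustrative names, not the kernel code.
rng = np.random.default_rng(0)
M, K, N = 64, 128, 96

a = rng.standard_normal((M, K)).astype(np.float32)        # activations (stand-in for bf16)
w_int4 = rng.integers(-8, 8, size=(K, N), dtype=np.int8)  # INT4 weight values
w_scale = np.float32(0.05)                                 # assumed per-tile weight scale

# Dynamic per-row activation quantization to INT8.
a_scale = np.abs(a).max(axis=1, keepdims=True) / 127.0     # per-row scale
a_scale = np.maximum(a_scale, 1e-8)
a_int8 = np.clip(np.rint(a / a_scale), -127, 127).astype(np.int8)

# INT4 weights "dequantize directly to INT8": the stored values already fit in int8.
b_int8 = w_int4.astype(np.int8)

# INT8 x INT8 matmul with int32 accumulation (what tl.dot(int8, int8) yields).
acc_i32 = a_int8.astype(np.int32) @ b_int8.astype(np.int32)

# Per-tile float32 rescale back to real values.
out_w4a8 = acc_i32.astype(np.float32) * a_scale * w_scale

# Compare against the unquantized-activation baseline.
out_ref = a @ (w_int4.astype(np.float32) * w_scale)
cos = (out_w4a8 * out_ref).sum() / (np.linalg.norm(out_w4a8) * np.linalg.norm(out_ref))
print(f"cosine similarity vs float baseline: {cos:.4f}")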
1 parent 0a61d6d commit a6aba5e

1 file changed: 14 additions & 8 deletions

backends/cuda/triton/kernels/fused_moe.py
@@ -707,11 +707,11 @@ def _fused_moe_batched_kernel(
 # Autotune configs for batched INT8 GEMM1 (gate+up projection, W4A8).
 _BATCHED_GEMM1_INT8_CONFIGS = [
     triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=3),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2),
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2
     ),
+    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
+    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
     triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4),
     triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4),
 ]
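Not part of this diff, but for context: a config list like _BATCHED_GEMM1_INT8_CONFIGS is normally attached to a kernel through triton.autotune, which benchmarks every config for each new autotune key and caches the winner. A minimal illustrative example with a made-up kernel and key:

import triton
import triton.language as tl

# Illustrative only: shows how a config list such as _BATCHED_GEMM1_INT8_CONFIGS
# is usually consumed. The kernel and key below are invented for the example.
_CONFIGS = [
    triton.Config({"BLOCK_SIZE": 256}, num_warps=4, num_stages=2),
    triton.Config({"BLOCK_SIZE": 512}, num_warps=8, num_stages=2),
]

@triton.autotune(configs=_CONFIGS, key=["n_elements"])  # re-tunes when n_elements changes
@triton.jit
def _scale_kernel(x_ptr, out_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * scale, mask=mask)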
@@ -833,7 +833,10 @@ def _fused_moe_batched_int8_kernel(
         else:
             # Multi-group tile: dequantize weights per group, use float matmul
             b_dequant = (b_int8.to(tl.float32) * b_scale).to(compute_type)
-            acc += tl.dot(a_int8.to(compute_type), b_dequant).to(tl.float32) * a_scale[:, None]
+            acc += (
+                tl.dot(a_int8.to(compute_type), b_dequant).to(tl.float32)
+                * a_scale[:, None]
+            )
 
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
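The multi-group fallback above applies a_scale[:, None] after the dot; that is valid because the per-row activation scale is constant along K within a tile, so it factors out of the matmul. A small NumPy check of that identity (illustrative shapes, not code from this file):

import numpy as np

# Why `* a_scale[:, None]` after the dot is valid: the per-row activation scale
# is constant along K within a tile, so it factors out of the matmul.
rng = np.random.default_rng(1)
M, K, N = 8, 32, 16
a_int8 = rng.integers(-127, 128, (M, K)).astype(np.float32)
b_dequant = rng.standard_normal((K, N)).astype(np.float32)  # already weight-scaled
a_scale = rng.uniform(0.01, 0.1, (M, 1)).astype(np.float32)

lhs = (a_int8 * a_scale) @ b_dequant   # scale activations before the matmul
rhs = (a_int8 @ b_dequant) * a_scale   # scale the accumulator per row afterwards
assert np.allclose(lhs, rhs, rtol=1e-5)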
@@ -977,11 +980,11 @@ def _fused_moe_silu_batched_kernel(
 # Autotune configs for batched INT8 GEMM2 (down projection + SiLU, W4A8).
 _BATCHED_GEMM2_INT8_CONFIGS = [
     triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2),
-    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2),
-    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
     triton.Config(
-        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3
+        {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, num_warps=4, num_stages=2
     ),
+    triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
+    triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_warps=4, num_stages=3),
     triton.Config({"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4),
     triton.Config({"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_warps=4, num_stages=4),
 ]
@@ -1105,7 +1108,10 @@ def _fused_moe_silu_batched_int8_kernel(
         else:
             # Multi-group tile: dequantize weights per group, use float matmul
             b_dequant = (b_int8.to(tl.float32) * b_scale).to(compute_type)
-            acc += tl.dot(a_int8.to(compute_type), b_dequant).to(tl.float32) * a_scale[:, None]
+            acc += (
+                tl.dot(a_int8.to(compute_type), b_dequant).to(tl.float32)
+                * a_scale[:, None]
+            )
 
         a_gate_ptrs += BLOCK_SIZE_K * stride_ak
         a_up_ptrs += BLOCK_SIZE_K * stride_ak
