
Commit d936717

Hoist W4A8 activation quantization out of GEMM K-loop
Add dedicated _quantize_activations_int8_kernel and _silu_quantize_int8_kernel that pre-quantize activations to INT8 with per-row-per-tile FP32 scales before GEMM1 and GEMM2, respectively. The existing _fused_moe_batched_int8_kernel and _fused_moe_silu_batched_int8_kernel are rewritten to consume the pre-quantized activations and scales, eliminating ~256 redundant tl.max reductions per program (cdiv(K, BLOCK_K) tiles * BLOCK_M rows) and halving activation HBM bandwidth in the K-loop (bf16 -> int8). BLOCK_SIZE_K is fixed at PREQUANT_BLOCK_K (= 128) so the per-tile activation scales align with the GEMM K-loop.

Correctness: 7/7 microbenchmark configs pass with rel diff < 1.5% vs the BF16 reference.

End-to-end (Qwen3.5 MoE, 1600 prefill + 512 decode, --cuda_graph, A100): prefill 5727 -> 6171 tok/s (+7.7%), decode 92.6 -> 99.0 tok/s (+6.9%).
1 parent 87c9947 commit d936717

5 files changed

Lines changed: 261 additions & 109 deletions
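For context, a minimal PyTorch sketch of the pre-quantization pass described in the commit message: symmetric INT8, one FP32 scale per row and per 128-wide K-tile, computed once up front rather than re-derived via tl.max inside every GEMM program. The function name and tensor layout here are illustrative, not the actual kernel signature.

import torch

PREQUANT_BLOCK_K = 128  # matches the fixed BLOCK_SIZE_K from the commit message

def quantize_activations_int8_ref(x: torch.Tensor, block_k: int = PREQUANT_BLOCK_K):
    # x: (M, K) bf16/fp32 activations; K assumed divisible by block_k for brevity.
    M, K = x.shape
    tiles = x.float().reshape(M, K // block_k, block_k)
    # One FP32 scale per (row, K-tile): amax over each 128-wide tile.
    scales = tiles.abs().amax(dim=-1).clamp(min=1e-8) / 127.0  # (M, K // block_k)
    q = torch.round(tiles / scales.unsqueeze(-1)).clamp(-127, 127).to(torch.int8)
    return q.reshape(M, K), scales

Hoisting this out of the GEMM means each of the cdiv(K, BLOCK_K) * BLOCK_M reductions formerly repeated per program is done exactly once per tile, and the K-loop then streams int8 activations instead of bf16.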


backends/cuda/tests/test_int4_matmul.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@
 import unittest
 
 import torch
-import torch.nn as nn
 
 from executorch.backends.cuda.triton.kernels.int4_matmul import (
     dequant_w4_to_bf16,

backends/cuda/triton/kernels/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
     fused_moe,
     fused_moe_batched,
     fused_moe_batched_gemm,
+    fused_moe_batched_gemm_int8,
     moe_align_block_size,
 )
 
@@ -23,6 +24,8 @@
     "fused_moe",
     "fused_moe_batched",
     "fused_moe_batched_gemm",
+    "fused_moe_batched_gemm_int8",
+    "int4_matvec",
     "moe_align_block_size",
     "sdpa",
     "sdpa_decode_splitk",
