Skip to content

Commit 32c49a3

Browse files
committed
Add W4A8 INT8 activation kernels for batched MoE prefill
INT8 tensor core variants of the batched MoE GEMM kernels that dynamically quantize bf16 activations to INT8 per-row per-tile and dequantize INT4 weights directly to INT8 (skipping bf16 conversion). Uses tl.dot(int8, int8) → int32 accumulation with per-tile float32 rescale. 1.7× MoE speedup on A100 at M=1024 with 0.9998 cosine similarity vs bf16 baseline. Co-authored-by: Claude <noreply@anthropic.com> ghstack-source-id: a153b52 Pull Request resolved: #19187
1 parent cb4e5ae commit 32c49a3

4 files changed

Lines changed: 501 additions & 4 deletions

File tree

backends/cuda/benchmarks/benchmark_moe.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,34 @@ def _run_triton_batched(
247247
)
248248

249249
BACKENDS["triton_batched"] = ("Triton batched", _run_triton_batched)
250+
251+
def _run_triton_batched_int8(
    hidden_states,
    w1,
    w1_scale,
    w2,
    w2_scale,
    topk_weights,
    topk_ids,
    top_k,
    num_experts,
    group_size,
):
    """Benchmark adapter: batched Triton MoE with INT8-quantized activations.

    Identical to the bf16 batched backend except that it asks
    ``fused_moe_batched`` to dynamically quantize activations to INT8
    (``activation_dtype="int8"``) before the GEMMs.
    """
    # Gather the purely positional arguments once, then splat them so the
    # keyword-only configuration stands out at the call site.
    tensors = (hidden_states, w1, w1_scale, w2, w2_scale, topk_weights, topk_ids)
    return fused_moe_batched(
        *tensors,
        top_k=top_k,
        num_experts=num_experts,
        group_size=group_size,
        activation_dtype="int8",
    )


BACKENDS["triton_batched_int8"] = ("Triton bat-i8", _run_triton_batched_int8)
250278
except ImportError:
251279
pass
252280

@@ -358,6 +386,15 @@ def run_benchmark(
358386
f"Triton vs eager mismatch at M={M}: "
359387
f"max abs error {err:.3e} >= 2.0e-1"
360388
)
389+
if "triton_batched_int8" in BACKENDS:
    # NOTE(review): both visible registrations store (label, fn) 2-tuples,
    # so a three-way unpack (`_, _, run_int8`) would raise ValueError here.
    # Unpack two values; confirm no other registration uses 3-tuples.
    _, run_int8 = BACKENDS["triton_batched_int8"]
    int8_out = run_int8(**common_args)
    int8_err = _max_abs_error(int8_out, ref_out)
    # INT8 activation quantization loses precision vs the bf16 path, so
    # this tolerance (5e-1) is looser than the bf16 Triton check (2e-1).
    assert int8_err < 5.0e-1, (
        f"Triton INT8 vs eager mismatch at M={M}: "
        f"max abs error {int8_err:.3e} >= 5.0e-1"
    )
    # Free the INT8 output tensor before benchmarking to keep peak
    # device memory down, matching the existing `del ref_out, tri_out`.
    del int8_out
361398
del ref_out, tri_out
362399

363400
# Benchmark

0 commit comments

Comments
 (0)