Commit dbcc10f

Update on "Add W4A8 INT8 activation kernels for batched MoE prefill"
INT8 tensor-core variants of the batched MoE GEMM kernels that dynamically quantize bf16 activations to INT8 per-row, per-tile and dequantize INT4 weights directly to INT8 (skipping the bf16 conversion). Uses tl.dot(int8, int8) → int32 accumulation with a per-tile float32 rescale. 1.7× MoE speedup on A100 at M=1024 with 0.9998 cosine similarity vs the bf16 baseline.

Co-authored-by: Claude <noreply@anthropic.com>

[ghstack-poisoned]
1 parent 594009d · commit dbcc10f
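For intuition, here is a hedged reference sketch of the W4A8 math the commit message describes, written in plain PyTorch rather than Triton. This is not the kernel from this commit: the function name, the unpacked [-8, 7] INT4 weight layout, and the per-group float32 weight scales (group_size=32, as set in model.py below) are assumptions for illustration; the real kernel quantizes per row per tile inside the GEMM and accumulates with tl.dot on tensor cores.

    import torch

    def w4a8_moe_gemm_reference(x_bf16, w_int4, w_scales, group_size=32):
        """Per-row dynamic INT8 activation quant + INT8xINT8->INT32 GEMM + f32 rescale."""
        # x_bf16:   [M, K]  bf16 activations
        # w_int4:   [N, K]  int8 tensor holding unpacked INT4 values in [-8, 7] (assumed layout)
        # w_scales: [N, K // group_size]  float32 per-group weight scales (assumed layout)
        M, K = x_bf16.shape
        N = w_int4.shape[0]

        # 1) Dynamic, symmetric per-row INT8 quantization of the activations.
        x_f32 = x_bf16.float()
        x_scale = x_f32.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0   # [M, 1]
        x_int8 = torch.clamp((x_f32 / x_scale).round(), -128, 127).to(torch.int8)

        # 2) INT8 x INT8 -> INT32 accumulation per K-group (stand-in for tl.dot over
        #    K-tiles). The INT4 weight values are used directly as INT8 operands; no
        #    intermediate bf16 dequantization.
        out = torch.zeros(M, N, dtype=torch.float32)
        for g in range(K // group_size):
            ks = slice(g * group_size, (g + 1) * group_size)
            xg = x_int8[:, ks].to(torch.int32)                                    # [M, Kg]
            wg = w_int4[:, ks].to(torch.int32)                                    # [N, Kg]
            acc = (xg.unsqueeze(1) * wg.unsqueeze(0)).sum(-1, dtype=torch.int32)  # [M, N]
            # 3) Per-tile float32 rescale with this group's weight scale.
            out += acc.float() * w_scales[:, g]                                   # broadcast [N]

        # 4) Fold in the per-row activation scale and return bf16.
        return (out * x_scale).to(torch.bfloat16)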

3 files changed

Lines changed: 13 additions & 8 deletions

backends/cuda/tests/test_fused_moe.py

Lines changed: 5 additions & 0 deletions
@@ -213,6 +213,11 @@ def _run_cpp_runner(runner_path, pte_path, ptd_path, input_files, output_base):
 
 
 class TestFusedMoE(unittest.TestCase):
+    # TODO: migrate from manual max_abs/max_ref relative checks to
+    # torch.allclose(atol=, rtol=). Current tests use per-tensor-max relative
+    # error which is looser than per-element allclose — need to calibrate atol
+    # for INT4 quantization noise floor across random weight magnitudes.
+
     def setUp(self):
         if not torch.cuda.is_available():
             self.skipTest("CUDA is not available")
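A note on the TODO added in this hunk: below is a minimal sketch of the two accuracy checks it contrasts, with illustrative names and tolerances (the real atol would need the calibration the TODO calls for).

    import torch

    def per_tensor_max_relative_ok(out, ref, threshold=2e-2):
        # Current style: worst absolute error normalized by the largest reference
        # magnitude, so small-magnitude elements are judged very leniently.
        max_abs = (out - ref).abs().max()
        max_ref = ref.abs().max().clamp(min=1e-12)
        return (max_abs / max_ref).item() < threshold

    def per_element_allclose_ok(out, ref, atol=1e-2, rtol=2e-2):
        # Target style: every element must satisfy |out - ref| <= atol + rtol * |ref|,
        # so atol has to sit above the INT4 quantization noise floor or the test
        # will flake on unlucky random weight draws.
        return torch.allclose(out.float(), ref.float(), atol=atol, rtol=rtol)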

examples/models/qwen3_5_moe/export.py

Lines changed: 6 additions & 6 deletions
@@ -535,12 +535,12 @@ def _apply_turboquant(model, config):
 # ---------------------------------------------------------------------------
 
 
-def _set_batched_moe(model, enabled, activation_dtype="bf16"):
+def _set_batched_moe(model, enabled, moe_activation_dtype="bf16"):
     """Toggle batched tensor-core MoE kernel for all MoE layers."""
     for layer in model.layers:
         if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"):
             layer.mlp.experts.use_batched_moe = enabled
-            layer.mlp.experts.activation_dtype = activation_dtype
+            layer.mlp.experts.moe_activation_dtype = moe_activation_dtype
 
 
 def export_and_lower(model, config, args):
@@ -783,8 +783,8 @@ def _export_cuda(model, config, args):
     # chunk_gated_delta_rule with CHUNK_SIZE=64) for the full range of sequence
     # lengths. Smaller examples cause AOTI to bake in intermediate buffer sizes
     # that reject longer prompts at runtime.
-    activation_dtype = getattr(args, "activation_dtype", "bf16")
-    _set_batched_moe(model, True, activation_dtype=activation_dtype)
+    moe_activation_dtype = getattr(args, "moe_activation_dtype", "bf16")
+    _set_batched_moe(model, True, moe_activation_dtype=moe_activation_dtype)
     print("Exporting prefill method...")
 
     example_prefill_len = config.max_seq_len - 1
@@ -949,10 +949,10 @@ def main(): # noqa: C901
         help="Disable split-K (flash-decoding) SDPA for decode; use tiled SDPA instead.",
     )
     parser.add_argument(
-        "--activation-dtype",
+        "--moe-activation-dtype",
         choices=["bf16", "int8"],
         default="bf16",
-        help="Activation dtype for batched MoE prefill kernels (bf16=W4A16, int8=W4A8).",
+        help="MoE activation dtype for prefill only. Decode always uses bf16. bf16 (default): W4A16 batched GEMM. int8: W4A8 with INT8 tensor cores (~1.5x faster prefill).",
     )
     args = parser.parse_args()
 
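To make the new help text concrete, here is a hedged sketch of how the flag plausibly plays out across the two exported methods. Only the prefill half appears in this diff; keeping decode on bf16 by leaving the batched kernel disabled is an assumption, and export_prefill/export_decode are placeholders rather than functions from this file.

    # Sketch only, based on the --moe-activation-dtype help text above.
    moe_activation_dtype = getattr(args, "moe_activation_dtype", "bf16")

    # Prefill: batched tensor-core MoE, W4A16 (bf16) or W4A8 (int8).
    _set_batched_moe(model, True, moe_activation_dtype=moe_activation_dtype)
    export_prefill(model)   # placeholder for the actual prefill export

    # Decode: always the bf16 per-token path, regardless of the flag
    # (assumption: the batched prefill kernel is simply left disabled here).
    _set_batched_moe(model, False)
    export_decode(model)    # placeholder for the actual decode export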
examples/models/qwen3_5_moe/model.py

Lines changed: 2 additions & 2 deletions
@@ -479,7 +479,7 @@ def __init__(self, config):
         self.hidden_size = config.hidden_size
         self.group_size = 32
         self.use_batched_moe = False
-        self.activation_dtype = "bf16"
+        self.moe_activation_dtype = "bf16"
 
         self.w1_weight = nn.Parameter(
             torch.empty(
@@ -498,7 +498,7 @@ def __init__(self, config):
 
     def forward(self, x, expert_weights, expert_indices, top_k):
         if self.use_batched_moe:
-            if self.activation_dtype == "int8":
+            if self.moe_activation_dtype == "int8":
                 return torch.ops.triton.fused_moe_batched_gemm_int8(
                     x,
                     self.w1,
