
Commit 21c2f4b

Update on "Add Triton INT4 dense kernels with dequant prefill path for Qwen3.5 MoE"
Add three new Triton kernels for dense W4A16 linear projections that replace tinygemm's tiled INT4 format with simple [N, K//2] packed weights (same format as MoE experts):

- int4_matmul: fused dequant + tl.dot GEMM for medium M (prefill crossover)
- int4_matvec: bandwidth-optimized vec-mat for M=1 decode
- dequant_w4_to_bf16: weight dequant for large-M prefill via Inductor mm

W4DequantLinear wraps these with dual decode/prefill dispatch:

- Decode (M=1): int4_matvec (73 tok/s, ~35% slower than tinygemm)
- Prefill (M>1): dequant + F.linear via Inductor (3400 tok/s at 3K tokens, +67% over the tinygemm baseline)

Single 18GB weight blob (no duplication). The decode perf regression is a known trade-off for the uniform weight format; to be revisited with a CUDA C++ matvec kernel.

Also adds INT8 dynamic-activation MoE tests and comprehensive correctness tests (48 tests, all passing at rtol=0.01).

Co-authored-by: Claude <noreply@anthropic.com>

[ghstack-poisoned]
2 parents 3e518f0 + eb03574 commit 21c2f4b
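For orientation, here is a minimal pure-PyTorch sketch of the decode/prefill dispatch described in the commit message. It is not the code added by this commit: the nibble order of the [N, K//2] packing, the scale/zero shapes, the (q - zero) * scale convention, and the reference helpers standing in for the int4_matvec / dequant_w4_to_bf16 Triton kernels are all assumptions.

import torch
import torch.nn.functional as F


def _dequant_w4_to_bf16_ref(packed, scales, zeros, group_size=32):
    """Pure-PyTorch stand-in for the dequant_w4_to_bf16 kernel.

    Assumes packed is [N, K//2] uint8 with two unsigned INT4 values per byte
    and per-group scales/zeros of shape [N, K // group_size].
    """
    lo = (packed & 0x0F).to(torch.int32)            # even K positions (assumed order)
    hi = (packed >> 4).to(torch.int32)              # odd K positions (assumed order)
    q = torch.stack((lo, hi), dim=-1).flatten(-2)   # interleave back to [N, K]
    n, k = q.shape
    q = q.view(n, k // group_size, group_size).to(torch.bfloat16)
    s = scales.unsqueeze(-1)
    z = zeros.unsqueeze(-1)
    return ((q - z) * s).reshape(n, k)


def _int4_matvec_ref(x, packed, scales, zeros, group_size=32):
    """Stand-in for the int4_matvec decode kernel; the real Triton kernel
    dequantizes on the fly and never materializes the BF16 weight."""
    w = _dequant_w4_to_bf16_ref(packed, scales, zeros, group_size)
    return F.linear(x, w)


class W4DequantLinearSketch(torch.nn.Module):
    """Decode/prefill dispatch as described in the commit summary (sketch only)."""

    def __init__(self, packed, scales, zeros, group_size=32):
        super().__init__()
        self.packed = packed          # [N, K//2] uint8, same layout as the MoE experts
        self.scales = scales          # [N, K // group_size] bf16
        self.zeros = zeros            # [N, K // group_size] bf16
        self.group_size = group_size

    def forward(self, x):
        m = x.numel() // x.shape[-1]  # token count
        if m == 1:
            # Decode (M=1): bandwidth-bound vec-mat path (int4_matvec).
            return _int4_matvec_ref(x, self.packed, self.scales, self.zeros,
                                    self.group_size)
        # Prefill (M>1): dequant W4 -> BF16 once, then a dense GEMM that
        # Inductor lowers to a plain mm.
        w = _dequant_w4_to_bf16_ref(self.packed, self.scales, self.zeros,
                                    self.group_size)
        return F.linear(x, w)

The point of the split is that decode is bandwidth-bound, so the weight should stay packed in INT4, while prefill is compute-bound, so a one-time dequant followed by a dense BF16 GEMM wins once M is large enough.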

3 files changed

Lines changed: 14 additions & 11 deletions


.ci/scripts/export_model_artifact.sh

Lines changed: 3 additions & 1 deletion
@@ -418,7 +418,9 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
   TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \
     python -m executorch.examples.models.qwen3_5_moe.export \
       --prequantized "$LOCAL_MODEL_DIR" \
-      --output-dir "${OUTPUT_DIR}"
+      --output-dir "${OUTPUT_DIR}" \
+      --dense-prefill dequant \
+      --moe-activation-dtype int8
   echo "::endgroup::"
 
   test -f "${OUTPUT_DIR}/model.pte"

examples/models/qwen3_5_moe/export.py

Lines changed: 9 additions & 8 deletions
@@ -667,12 +667,12 @@ def _apply_turboquant(model, config):
 # ---------------------------------------------------------------------------
 
 
-def _set_batched_moe(model, enabled, moe_moe_moe_activation_dtype="bf16"):
+def _set_batched_moe(model, enabled, moe_activation_dtype="bf16"):
     """Toggle batched tensor-core MoE kernel for all MoE layers."""
     for layer in model.layers:
         if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"):
             layer.mlp.experts.use_batched_moe = enabled
-            layer.mlp.experts.moe_moe_moe_activation_dtype = moe_moe_moe_activation_dtype
+            layer.mlp.experts.moe_activation_dtype = moe_activation_dtype
 
 
 def export_and_lower(model, config, args):
@@ -916,8 +916,8 @@ def _export_cuda(model, config, args):
     # chunk_gated_delta_rule with CHUNK_SIZE=64) for the full range of sequence
     # lengths. Smaller examples cause AOTI to bake in intermediate buffer sizes
     # that reject longer prompts at runtime.
-    moe_moe_moe_activation_dtype = getattr(args, "moe_moe_moe_activation_dtype", "bf16")
-    _set_batched_moe(model, True, moe_moe_moe_activation_dtype=moe_moe_moe_activation_dtype)
+    moe_activation_dtype = getattr(args, "moe_activation_dtype", "bf16")
+    _set_batched_moe(model, True, moe_activation_dtype=moe_activation_dtype)
     dense_prefill = getattr(args, "dense_prefill", "tinygemm")
     _set_dequant_prefill(model, dense_prefill == "dequant")
     print("Exporting prefill method...")
@@ -1087,14 +1087,15 @@ def main(): # noqa: C901
         "--moe-activation-dtype",
         choices=["bf16", "int8"],
         default="bf16",
-        help="MoE activation dtype for prefill only. Decode always uses bf16. bf16 (default): W4A16 batched GEMM. int8: W4A8 with INT8 tensor cores (~1.5x faster prefill).",
+        help="MoE activation dtype for prefill only. Decode always uses bf16. bf16 (default): W4A16 batched GEMM. int8: W4A8 with INT8 tensor cores.",
     )
     parser.add_argument(
         "--dense-prefill",
         choices=["tinygemm", "dequant"],
         default="tinygemm",
-        help="Dense linear kernel: tinygemm (default W4A16 INT4 kernel) or "
-        "dequant (dequant W4→BF16 + Inductor mm for prefill, int4_matvec for decode).",
+        help="Dense linear prefill kernel. Decode always uses int4_matvec (Triton W4A16 vec-mat). "
+        "tinygemm (default): W4A16 _weight_int4pack_mm. "
+        "dequant: dequant W4→BF16 + cuBLAS GEMM.",
     )
     args = parser.parse_args()
 
@@ -1139,7 +1140,7 @@ def main(): # noqa: C901
             "(dense weights must be W4 quantized)"
         )
 
-    if args.moe_moe_activation_dtype != "bf16" and args.backend != "cuda":
+    if args.moe_activation_dtype != "bf16" and args.backend != "cuda":
         parser.error("--moe-activation-dtype int8 requires --backend cuda")
 
     model, config = load_and_quantize(args)

examples/models/qwen3_5_moe/model.py

Lines changed: 2 additions & 2 deletions
@@ -479,7 +479,7 @@ def __init__(self, config):
         self.hidden_size = config.hidden_size
         self.group_size = 32
         self.use_batched_moe = False
-        self.moe_moe_activation_dtype = "bf16"
+        self.moe_activation_dtype = "bf16"
 
         self.w1_weight = nn.Parameter(
             torch.empty(
@@ -498,7 +498,7 @@ def __init__(self, config):
 
     def forward(self, x, expert_weights, expert_indices, top_k):
         if self.use_batched_moe:
-            if self.moe_moe_activation_dtype == "int8":
+            if self.moe_activation_dtype == "int8":
                 return torch.ops.triton.fused_moe_batched_gemm_int8(
                     x,
                     self.w1,
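The group_size of 32 in the hunk above and the [N, K//2] layout from the commit message imply a simple group-wise packing for the INT4 weights. As a rough illustration only (asymmetric per-group quantization and nibble order are assumptions, not the repo's actual packing routine), packing a float weight into that layout could look like:

import torch


def pack_w4(weight, group_size=32):
    """Quantize an [N, K] float weight to unsigned INT4 and pack two values per
    byte into [N, K//2], with per-group scales/zeros of shape [N, K // group_size].
    Matches the (q - zero) * scale dequant convention used in the earlier sketch.
    """
    n, k = weight.shape
    g = weight.view(n, k // group_size, group_size).float()
    w_min = g.amin(dim=-1)                              # [N, K // group_size]
    w_max = g.amax(dim=-1)
    scales = (w_max - w_min).clamp(min=1e-6) / 15.0
    zeros = -w_min / scales                             # float zero-point
    q = torch.round(g / scales.unsqueeze(-1) + zeros.unsqueeze(-1))
    q = q.clamp(0, 15).to(torch.uint8).reshape(n, k)
    packed = q[:, 0::2] | (q[:, 1::2] << 4)             # low nibble = even K index
    return packed, scales.to(torch.bfloat16), zeros.to(torch.bfloat16)

Dequantizing with the reference helper from the sketch after the commit message should round-trip these weights to within one quantization step per group.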
