3 changes: 2 additions & 1 deletion .ci/scripts/export_model_artifact.sh
@@ -418,7 +418,8 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
   TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \
   python -m executorch.examples.models.qwen3_5_moe.export \
     --prequantized "$LOCAL_MODEL_DIR" \
-    --output-dir "${OUTPUT_DIR}"
+    --output-dir "${OUTPUT_DIR}" \
+    --moe-activation-dtype int8
   echo "::endgroup::"
 
   test -f "${OUTPUT_DIR}/model.pte"
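The new --moe-activation-dtype int8 flag exercises the INT8 activation path that the tests below cover. As a hedged illustration of how such a flag could be wired into an export entrypoint, here is a minimal argparse sketch; only the flag name and the "int8" value come from the diff, while the parser structure, the bf16 default, and the choices list are assumptions, not the actual executorch.examples.models.qwen3_5_moe.export code.

# Hypothetical sketch only: argument wiring assumed, not taken from the
# real qwen3_5_moe export module.
import argparse

parser = argparse.ArgumentParser(description="MoE model export (sketch)")
parser.add_argument("--prequantized", type=str, required=True)
parser.add_argument("--output-dir", type=str, required=True)
parser.add_argument(
    "--moe-activation-dtype",
    choices=["bf16", "int8"],  # assumed set of supported dtypes
    default="bf16",            # assumed default
    help="Activation dtype for the fused MoE GEMM kernels.",
)
args = parser.parse_args()
# Downstream, the export path would select the INT8 batched MoE kernel
# when this evaluates to True.
use_int8_moe = args.moe_activation_dtype == "int8"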
152 changes: 152 additions & 0 deletions backends/cuda/tests/test_fused_moe.py
@@ -31,6 +31,7 @@
 from executorch.backends.cuda.triton.kernels.fused_moe import (
     fused_moe as triton_fused_moe,
     fused_moe_batched as triton_fused_moe_batched,
+    fused_moe_batched_gemm_int8 as triton_fused_moe_batched_int8,
     moe_align_block_size,
 )
 from executorch.exir import (
@@ -212,6 +213,11 @@ def _run_cpp_runner(runner_path, pte_path, ptd_path, input_files, output_base):


 class TestFusedMoE(unittest.TestCase):
+    # TODO: migrate from manual max_abs/max_ref relative checks to
+    # torch.allclose(atol=, rtol=). Current tests use per-tensor-max relative
+    # error which is looser than per-element allclose — need to calibrate atol
+    # for INT4 quantization noise floor across random weight magnitudes.
+
     def setUp(self):
         if not torch.cuda.is_available():
             self.skipTest("CUDA is not available")
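As an illustration of the migration the TODO describes, here is a hedged sketch of both checks side by side; the 5e-2 tolerances are uncalibrated placeholders, not the values the TODO says still need to be determined for the INT4 noise floor.

import torch

def max_rel_error(out: torch.Tensor, ref: torch.Tensor) -> float:
    # Current style: max abs diff, normalized by the max magnitude of ref.
    # A single large reference element can mask large errors elsewhere.
    diff = (out.float() - ref.float()).abs().max().item()
    return diff / (ref.float().abs().max().item() + 1e-10)

def allclose_check(out: torch.Tensor, ref: torch.Tensor) -> bool:
    # Target style: per-element tolerance. atol must absorb the INT4
    # quantization noise; the values here are placeholders.
    return torch.allclose(out.float(), ref.float(), atol=5e-2, rtol=5e-2)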
@@ -487,6 +493,152 @@ def test_e2e_cpp_runner(self):
         )
 
+
+class TestFusedMoEBatchedInt8(unittest.TestCase):
+    """Correctness tests for the INT8 dynamic-activation batched MoE kernel."""
+
+    INT8_TEST_CONFIGS = [
+        (42, 8, 64, 32, 4, 2, 32, "8tok_small"),
+        (7, 16, 64, 32, 8, 4, 32, "16tok_8exp_top4"),
+        (13, 32, 128, 64, 8, 2, 64, "32tok_gs64"),
+        (55, 64, 64, 32, 4, 2, 32, "64tok"),
+        (99, 128, 128, 64, 8, 2, 32, "128tok"),
+        (0, 256, 128, 64, 8, 2, 32, "256tok"),
+    ]
+
+    def test_int8_correctness(self):
+        """INT8 batched kernel matches reference across M values."""
+        for (
+            seed,
+            M,
+            hidden,
+            intermediate,
+            num_experts,
+            top_k,
+            gs,
+            desc,
+        ) in self.INT8_TEST_CONFIGS:
+            with self.subTest(desc=desc):
+                torch.manual_seed(seed)
+                x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
+                w1_weight = torch.randn(
+                    num_experts,
+                    2 * intermediate,
+                    hidden,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w2_weight = torch.randn(
+                    num_experts,
+                    hidden,
+                    intermediate,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs)
+                w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs)
+                w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda()
+
+                scores = torch.randn(M, num_experts, device="cuda")
+                topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
+                topk_weights = topk_weights.softmax(dim=-1).float()
+
+                out_int8 = triton_fused_moe_batched_int8(
+                    x,
+                    w1,
+                    w1s,
+                    w2,
+                    w2s,
+                    topk_weights,
+                    topk_ids,
+                    top_k,
+                    num_experts,
+                    gs,
+                )
+
+                w1_dq = _dequantize_int4(w1.cpu(), w1s.cpu(), gs).cuda()
+                w2_dq = _dequantize_int4(w2.cpu(), w2s.cpu(), gs).cuda()
+                ref = _reference_moe(x, w1_dq, w2_dq, topk_weights, topk_ids, top_k)
+
+                diff = (out_int8.float() - ref.float()).abs().max().item()
+                rel = diff / (ref.float().abs().max().item() + 1e-10)
+                self.assertLess(
+                    rel,
+                    0.10,
+                    f"{desc}: relative diff {rel:.4f} (abs {diff:.6f})",
+                )
+
+    def test_int8_matches_bf16_batched(self):
+        """INT8 batched output is close to BF16 batched output."""
+        for (
+            seed,
+            M,
+            hidden,
+            intermediate,
+            num_experts,
+            top_k,
+            gs,
+            desc,
+        ) in self.INT8_TEST_CONFIGS:
+            with self.subTest(desc=desc):
+                torch.manual_seed(seed)
+                x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
+                w1_weight = torch.randn(
+                    num_experts,
+                    2 * intermediate,
+                    hidden,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w2_weight = torch.randn(
+                    num_experts,
+                    hidden,
+                    intermediate,
+                    dtype=torch.bfloat16,
+                    device="cuda",
+                )
+                w1, w1s = _quantize_weights_int4(w1_weight.cpu(), gs)
+                w2, w2s = _quantize_weights_int4(w2_weight.cpu(), gs)
+                w1, w1s, w2, w2s = w1.cuda(), w1s.cuda(), w2.cuda(), w2s.cuda()
+
+                scores = torch.randn(M, num_experts, device="cuda")
+                topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
+                topk_weights = topk_weights.softmax(dim=-1).float()
+
+                out_bf16 = triton_fused_moe_batched(
+                    x,
+                    w1,
+                    w1s,
+                    w2,
+                    w2s,
+                    topk_weights,
+                    topk_ids,
+                    top_k,
+                    num_experts,
+                    gs,
+                )
+
+                out_int8 = triton_fused_moe_batched_int8(
+                    x,
+                    w1,
+                    w1s,
+                    w2,
+                    w2s,
+                    topk_weights,
+                    topk_ids,
+                    top_k,
+                    num_experts,
+                    gs,
+                )
+
+                diff = (out_int8.float() - out_bf16.float()).abs().max().item()
+                rel = diff / (out_bf16.float().abs().max().item() + 1e-10)
+                self.assertLess(
+                    rel,
+                    0.15,
+                    f"{desc}: int8 vs bf16 relative diff {rel:.4f} (abs {diff:.6f})",
+                )
+
 
 class TestMoeAlignBlockSize(unittest.TestCase):
     def setUp(self):
         if not torch.cuda.is_available():
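Each INT8_TEST_CONFIGS tuple unpacks as (seed, M, hidden, intermediate, num_experts, top_k, group_size, desc), matching the for-loop headers in both tests. The class docstring calls the kernel "INT8 dynamic-activation": the weights stay INT4 (packed by _quantize_weights_int4) while activations are quantized to INT8 at runtime. Below is a hedged sketch of the per-token dynamic scheme such kernels typically use; the real Triton kernel fuses this with the batched GEMM, and the per-row scaling granularity shown here is an assumption, not confirmed by the diff.

import torch

def dynamic_quantize_int8(x: torch.Tensor):
    # One scale per token (row), computed from the runtime dynamic range.
    scale = x.float().abs().amax(dim=-1, keepdim=True) / 127.0
    scale = scale.clamp(min=1e-10)  # guard all-zero rows
    x_int8 = torch.round(x.float() / scale).clamp(-127, 127).to(torch.int8)
    return x_int8, scale

# Reference dequantization, e.g. for an eager-mode check:
# x_approx = x_int8.float() * scale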