Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
# Default to Opus unless CK sorting is explicitly requested.
_USE_CK_MOE_SORTING = os.environ.get("AITER_USE_CK_MOE_SORTING", "0") == "1"
_ACT_TYPE_DISABLED_KEY = "__ignore__"
_USE_GENERIC_SWIGLU_MXFP4_LAYOUT = (
os.environ.get("GPTOSS_USE_GENERIC_SWIGLU_MXFP4_LAYOUT", "0") == "1"
)
_SWIGLU_MXFP4_BF16_BOUND = int(os.environ.get("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256"))


Expand Down Expand Up @@ -308,7 +305,7 @@ def fused_moe_(
# a16wi4: bf16 activations, int4 weights with groupwise scale
q_dtype_a = dtypes.bf16
elif quant_type == QuantType.per_1x32:
if activation == ActivationType.Swiglu and _USE_GENERIC_SWIGLU_MXFP4_LAYOUT:
if activation == ActivationType.Swiglu and gate_mode == GateMode.SEPARATED:
Comment thread
coderfeli marked this conversation as resolved.
q_dtype_a = dtypes.bf16 if M < _SWIGLU_MXFP4_BF16_BOUND else dtypes.fp4x2
elif activation == ActivationType.Swiglu or gate_mode == GateMode.INTERLEAVE:
if get_gfx() != "gfx950" or M < bf16_fp8_bound:
Expand All @@ -334,6 +331,7 @@ def fused_moe_(
hidden_pad,
intermediate_pad,
isShuffled,
gate_mode,
)

block_size_M = metadata.block_m if block_size_M is None else block_size_M
Expand Down Expand Up @@ -410,6 +408,7 @@ def fused_moe_(
topk_weights=topk_weight,
# only for flydsl dsv4
swiglu_limit=swiglu_limit,
gate_mode=gate_mode,
)


Expand Down Expand Up @@ -819,7 +818,9 @@ def get_2stage_cfgs(
hidden_pad,
intermediate_pad,
is_shuffled=True,
gate_mode=GateMode.SEPARATED.value,
):
gate_mode = GateMode(gate_mode)
_INDEX_COLS = [
"cu_num",
"token",
Expand Down Expand Up @@ -1160,7 +1161,7 @@ def get_block_m() -> int:
stage2_has_bias=enable_bias and is_flydsl2,
)
if (
not _USE_GENERIC_SWIGLU_MXFP4_LAYOUT
gate_mode != GateMode.SEPARATED
and dtype in [dtypes.bf16, dtypes.fp16]
and q_type == QuantType.per_1x32
and activation == ActivationType.Swiglu
Expand Down Expand Up @@ -1449,8 +1450,10 @@ def fused_moe_2stages(
topk_ids=None,
topk_weights=None,
swiglu_limit=0.0,
gate_mode=GateMode.SEPARATED.value,
):
quant_func = get_quant(quant_type)
gate_mode = GateMode(gate_mode)
token_num, _ = hidden_states.shape
E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape)
dtype = moe_out.dtype
Expand All @@ -1472,6 +1475,7 @@ def fused_moe_2stages(
hidden_pad,
intermediate_pad,
is_shuffled,
gate_mode,
)
if (
quant_type == QuantType.per_1x32
Expand Down
16 changes: 10 additions & 6 deletions aiter/ops/flydsl/kernels/silu_and_mul_fq.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,13 @@ def _f32_to_e2m1(qx_f32):
)
gate = g
linear = u

t = g * neg_log2e

if const_expr(swiglu_limit != 0 and act != "swiglu"):
t = gate * neg_log2e
if const_expr(act == "swiglu"):
gate = arith.minimumf(gate, _limit)
linear = arith.minimumf(linear, _limit)
linear = arith.maximumf(linear, _neg_limit)
t = gate * swiglu_neg_alpha_log2e
elif const_expr(swiglu_limit != 0 and act != "swiglu"):
gate = arith.minimumf(gate, _limit)
linear = arith.minimumf(linear, _limit)
linear = arith.maximumf(linear, _neg_limit)
Expand All @@ -314,9 +317,10 @@ def _f32_to_e2m1(qx_f32):
f32, "llvm.amdgcn.rcp.f32", [den], [], []
)
if const_expr(act == "swiglu"):
act_vals.append(gate * sig * (linear + c1_f32))
act_v = gate * sig * (linear + c1_f32)
else:
act_vals.append(gate * sig * linear)
act_v = gate * sig * linear
act_vals.append(act_v)

if const_expr(_need_quant):
local_max = c0_f32
Expand Down
50 changes: 46 additions & 4 deletions op_tests/test_moe_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from aiter.fused_moe import (
fused_topk,
fused_moe,
get_2stage_cfgs,
get_padded_M,
torch_moe_stage1,
torch_moe_stage2,
)
Expand Down Expand Up @@ -244,13 +246,47 @@ def weight_per_128x128_quant(weight, quant_dtype):
w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale)

# # ######################## stage 1 start ###########
stage1_ref_dtype = dtype
if (
actType == aiter.ActivationType.Swiglu
and qType == aiter.QuantType.per_1x32
and WQDType == dtypes.fp4x2
):
runtime_aq_dtype = _runtime_swiglu_mxfp4_q_dtype_a(
token, actType, gateMode, qType, AQDType, WQDType
)
if runtime_aq_dtype == dtypes.fp4x2:
metadata = get_2stage_cfgs(
get_padded_M(token),
model_dim,
inter_dim,
E,
topk,
dtype,
runtime_aq_dtype,
WQDType,
qType,
w1.shape[1] == (inter_dim * 2),
actType,
doweight_stage1,
hidden_pad,
intermediate_pad,
getattr(w1_qt_aiter, "is_shuffled", False)
or getattr(w2_qt_aiter, "is_shuffled", False),
gateMode,
)
if metadata.fuse_quant == "fp4":
# Fused Swiglu MXFP4 quantizes the f32 activation directly.
# Keep the torch reference at f32 until the quantization step.
stage1_ref_dtype = dtypes.fp32

out1_ref = torch_moe_stage1(
a1_qt,
w1_qt,
w2_qt,
topk_weights,
topk_ids,
dtype=dtype,
dtype=stage1_ref_dtype,
activation=actType,
quant_type=qType,
a1_scale=a1_scale,
Expand Down Expand Up @@ -544,6 +580,8 @@ def _row_to_kwargs(row):
aq_dtype = _str2dtype(row["q_dtype_a"])
wq_dtype = _str2dtype(row["q_dtype_w"])
act_type = _str2enum(row["act_type"], aiter.ActivationType)
# Tuned CSV rows do not carry gate mode explicitly. Infer the runtime mode
# from the selected activation/weight dtype layout used by fused_moe.
gate_mode = _effective_gate_mode(aq_dtype, wq_dtype)
return dict(
dtype=_str2dtype(row["dtype"]),
Expand Down Expand Up @@ -594,6 +632,7 @@ def _iter_csv_cases():
expected_aq_dtype = _runtime_swiglu_mxfp4_q_dtype_a(
kwargs["token"],
kwargs["actType"],
kwargs["gateMode"],
kwargs["qType"],
kwargs["AQDType"],
kwargs["WQDType"],
Expand Down Expand Up @@ -634,7 +673,9 @@ def _effective_swiglu_limit(quant_type, aq_dtype, wq_dtype, swiglu_limit):
return 0.0


def _runtime_swiglu_mxfp4_q_dtype_a(token, act_type, q_type, aq_dtype, wq_dtype):
def _runtime_swiglu_mxfp4_q_dtype_a(
token, act_type, gate_mode, q_type, aq_dtype, wq_dtype
):
"""Return the q_dtype_a that fused_moe will select for Swiglu MXFP4."""
if act_type != aiter.ActivationType.Swiglu:
return None
Expand All @@ -643,11 +684,12 @@ def _runtime_swiglu_mxfp4_q_dtype_a(token, act_type, q_type, aq_dtype, wq_dtype)
if aq_dtype not in [dtypes.bf16, dtypes.fp16, dtypes.fp8, dtypes.fp4x2]:
return None

if os.environ.get("GPTOSS_USE_GENERIC_SWIGLU_MXFP4_LAYOUT", "0") == "1":
gate_mode = GateMode(gate_mode)
if gate_mode == GateMode.SEPARATED:
bound = int(os.environ.get("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256"))
return dtypes.bf16 if token < bound else dtypes.fp4x2

bound = int(os.environ.get("AITER_BF16_FP8_BOUND", "512"))
bound = int(os.environ.get("AITER_BF16_FP8_MOE_BOUND", "256"))
return dtypes.bf16 if get_gfx() != "gfx950" or token < bound else dtypes.fp8


Expand Down
Loading