diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 1b3ed4067b..f763b37ab3 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -25,9 +25,6 @@ # Default to Opus unless CK sorting is explicitly requested. _USE_CK_MOE_SORTING = os.environ.get("AITER_USE_CK_MOE_SORTING", "0") == "1" _ACT_TYPE_DISABLED_KEY = "__ignore__" -_USE_GENERIC_SWIGLU_MXFP4_LAYOUT = ( - os.environ.get("GPTOSS_USE_GENERIC_SWIGLU_MXFP4_LAYOUT", "0") == "1" -) _SWIGLU_MXFP4_BF16_BOUND = int(os.environ.get("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256")) @@ -308,7 +305,7 @@ def fused_moe_( # a16wi4: bf16 activations, int4 weights with groupwise scale q_dtype_a = dtypes.bf16 elif quant_type == QuantType.per_1x32: - if activation == ActivationType.Swiglu and _USE_GENERIC_SWIGLU_MXFP4_LAYOUT: + if activation == ActivationType.Swiglu and gate_mode == GateMode.SEPARATED: q_dtype_a = dtypes.bf16 if M < _SWIGLU_MXFP4_BF16_BOUND else dtypes.fp4x2 elif activation == ActivationType.Swiglu or gate_mode == GateMode.INTERLEAVE: if get_gfx() != "gfx950" or M < bf16_fp8_bound: @@ -334,6 +331,7 @@ def fused_moe_( hidden_pad, intermediate_pad, isShuffled, + gate_mode, ) block_size_M = metadata.block_m if block_size_M is None else block_size_M @@ -410,6 +408,7 @@ def fused_moe_( topk_weights=topk_weight, # only for flydsl dsv4 swiglu_limit=swiglu_limit, + gate_mode=gate_mode, ) @@ -819,7 +818,9 @@ def get_2stage_cfgs( hidden_pad, intermediate_pad, is_shuffled=True, + gate_mode=GateMode.SEPARATED.value, ): + gate_mode = GateMode(gate_mode) _INDEX_COLS = [ "cu_num", "token", @@ -1160,7 +1161,7 @@ def get_block_m() -> int: stage2_has_bias=enable_bias and is_flydsl2, ) if ( - not _USE_GENERIC_SWIGLU_MXFP4_LAYOUT + gate_mode != GateMode.SEPARATED and dtype in [dtypes.bf16, dtypes.fp16] and q_type == QuantType.per_1x32 and activation == ActivationType.Swiglu @@ -1449,8 +1450,10 @@ def fused_moe_2stages( topk_ids=None, topk_weights=None, swiglu_limit=0.0, + gate_mode=GateMode.SEPARATED.value, ): quant_func = get_quant(quant_type) + gate_mode = GateMode(gate_mode) token_num, _ = hidden_states.shape E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) dtype = moe_out.dtype @@ -1472,6 +1475,7 @@ def fused_moe_2stages( hidden_pad, intermediate_pad, is_shuffled, + gate_mode, ) if ( quant_type == QuantType.per_1x32 diff --git a/aiter/ops/flydsl/kernels/silu_and_mul_fq.py b/aiter/ops/flydsl/kernels/silu_and_mul_fq.py index 41e62c7848..eb4dcca312 100644 --- a/aiter/ops/flydsl/kernels/silu_and_mul_fq.py +++ b/aiter/ops/flydsl/kernels/silu_and_mul_fq.py @@ -297,10 +297,13 @@ def _f32_to_e2m1(qx_f32): ) gate = g linear = u - - t = g * neg_log2e - - if const_expr(swiglu_limit != 0 and act != "swiglu"): + t = gate * neg_log2e + if const_expr(act == "swiglu"): + gate = arith.minimumf(gate, _limit) + linear = arith.minimumf(linear, _limit) + linear = arith.maximumf(linear, _neg_limit) + t = gate * swiglu_neg_alpha_log2e + elif const_expr(swiglu_limit != 0 and act != "swiglu"): gate = arith.minimumf(gate, _limit) linear = arith.minimumf(linear, _limit) linear = arith.maximumf(linear, _neg_limit) @@ -314,9 +317,10 @@ def _f32_to_e2m1(qx_f32): f32, "llvm.amdgcn.rcp.f32", [den], [], [] ) if const_expr(act == "swiglu"): - act_vals.append(gate * sig * (linear + c1_f32)) + act_v = gate * sig * (linear + c1_f32) else: - act_vals.append(gate * sig * linear) + act_v = gate * sig * linear + act_vals.append(act_v) if const_expr(_need_quant): local_max = c0_f32 diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index eb77cdbfdb..c47b42ce81 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -22,6 +22,8 @@ from aiter.fused_moe import ( fused_topk, fused_moe, + get_2stage_cfgs, + get_padded_M, torch_moe_stage1, torch_moe_stage2, ) @@ -244,13 +246,47 @@ def weight_per_128x128_quant(weight, quant_dtype): w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) # # ######################## stage 1 start ########### + stage1_ref_dtype = dtype + if ( + actType == aiter.ActivationType.Swiglu + and qType == aiter.QuantType.per_1x32 + and WQDType == dtypes.fp4x2 + ): + runtime_aq_dtype = _runtime_swiglu_mxfp4_q_dtype_a( + token, actType, gateMode, qType, AQDType, WQDType + ) + if runtime_aq_dtype == dtypes.fp4x2: + metadata = get_2stage_cfgs( + get_padded_M(token), + model_dim, + inter_dim, + E, + topk, + dtype, + runtime_aq_dtype, + WQDType, + qType, + w1.shape[1] == (inter_dim * 2), + actType, + doweight_stage1, + hidden_pad, + intermediate_pad, + getattr(w1_qt_aiter, "is_shuffled", False) + or getattr(w2_qt_aiter, "is_shuffled", False), + gateMode, + ) + if metadata.fuse_quant == "fp4": + # Fused Swiglu MXFP4 quantizes the f32 activation directly. + # Keep the torch reference at f32 until the quantization step. + stage1_ref_dtype = dtypes.fp32 + out1_ref = torch_moe_stage1( a1_qt, w1_qt, w2_qt, topk_weights, topk_ids, - dtype=dtype, + dtype=stage1_ref_dtype, activation=actType, quant_type=qType, a1_scale=a1_scale, @@ -544,6 +580,8 @@ def _row_to_kwargs(row): aq_dtype = _str2dtype(row["q_dtype_a"]) wq_dtype = _str2dtype(row["q_dtype_w"]) act_type = _str2enum(row["act_type"], aiter.ActivationType) + # Tuned CSV rows do not carry gate mode explicitly. Infer the runtime mode + # from the selected activation/weight dtype layout used by fused_moe. gate_mode = _effective_gate_mode(aq_dtype, wq_dtype) return dict( dtype=_str2dtype(row["dtype"]), @@ -594,6 +632,7 @@ def _iter_csv_cases(): expected_aq_dtype = _runtime_swiglu_mxfp4_q_dtype_a( kwargs["token"], kwargs["actType"], + kwargs["gateMode"], kwargs["qType"], kwargs["AQDType"], kwargs["WQDType"], @@ -634,7 +673,9 @@ def _effective_swiglu_limit(quant_type, aq_dtype, wq_dtype, swiglu_limit): return 0.0 -def _runtime_swiglu_mxfp4_q_dtype_a(token, act_type, q_type, aq_dtype, wq_dtype): +def _runtime_swiglu_mxfp4_q_dtype_a( + token, act_type, gate_mode, q_type, aq_dtype, wq_dtype +): """Return the q_dtype_a that fused_moe will select for Swiglu MXFP4.""" if act_type != aiter.ActivationType.Swiglu: return None @@ -643,11 +684,12 @@ def _runtime_swiglu_mxfp4_q_dtype_a(token, act_type, q_type, aq_dtype, wq_dtype) if aq_dtype not in [dtypes.bf16, dtypes.fp16, dtypes.fp8, dtypes.fp4x2]: return None - if os.environ.get("GPTOSS_USE_GENERIC_SWIGLU_MXFP4_LAYOUT", "0") == "1": + gate_mode = GateMode(gate_mode) + if gate_mode == GateMode.SEPARATED: bound = int(os.environ.get("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256")) return dtypes.bf16 if token < bound else dtypes.fp4x2 - bound = int(os.environ.get("AITER_BF16_FP8_BOUND", "512")) + bound = int(os.environ.get("AITER_BF16_FP8_MOE_BOUND", "256")) return dtypes.bf16 if get_gfx() != "gfx950" or token < bound else dtypes.fp8