From abcff711e32372192ff4ff1e313e53cc327ad9f9 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Mon, 11 May 2026 02:44:51 -0500 Subject: [PATCH 1/2] [MoE] align Swiglu MXFP4 fused quant paths Remove the GPT-OSS Swiglu layout env switch in favor of GateMode, align the CSV test filter with runtime dtype selection, and restore FlyDSL Swiglu _fp4 fused quant accuracy by matching the non-fused bf16 stage1 semantics. Co-authored-by: Cursor --- aiter/fused_moe.py | 14 +++++++++----- .../flydsl/kernels/mixed_moe_gemm_2stage.py | 5 +++++ aiter/ops/flydsl/kernels/silu_and_mul_fq.py | 18 ++++++++++++++---- op_tests/test_moe_2stage.py | 10 +++++++--- 4 files changed, 35 insertions(+), 12 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 1b3ed4067b..f763b37ab3 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -25,9 +25,6 @@ # Default to Opus unless CK sorting is explicitly requested. _USE_CK_MOE_SORTING = os.environ.get("AITER_USE_CK_MOE_SORTING", "0") == "1" _ACT_TYPE_DISABLED_KEY = "__ignore__" -_USE_GENERIC_SWIGLU_MXFP4_LAYOUT = ( - os.environ.get("GPTOSS_USE_GENERIC_SWIGLU_MXFP4_LAYOUT", "0") == "1" -) _SWIGLU_MXFP4_BF16_BOUND = int(os.environ.get("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256")) @@ -308,7 +305,7 @@ def fused_moe_( # a16wi4: bf16 activations, int4 weights with groupwise scale q_dtype_a = dtypes.bf16 elif quant_type == QuantType.per_1x32: - if activation == ActivationType.Swiglu and _USE_GENERIC_SWIGLU_MXFP4_LAYOUT: + if activation == ActivationType.Swiglu and gate_mode == GateMode.SEPARATED: q_dtype_a = dtypes.bf16 if M < _SWIGLU_MXFP4_BF16_BOUND else dtypes.fp4x2 elif activation == ActivationType.Swiglu or gate_mode == GateMode.INTERLEAVE: if get_gfx() != "gfx950" or M < bf16_fp8_bound: @@ -334,6 +331,7 @@ def fused_moe_( hidden_pad, intermediate_pad, isShuffled, + gate_mode, ) block_size_M = metadata.block_m if block_size_M is None else block_size_M @@ -410,6 +408,7 @@ def fused_moe_( topk_weights=topk_weight, # only for flydsl dsv4 swiglu_limit=swiglu_limit, + gate_mode=gate_mode, ) @@ -819,7 +818,9 @@ def get_2stage_cfgs( hidden_pad, intermediate_pad, is_shuffled=True, + gate_mode=GateMode.SEPARATED.value, ): + gate_mode = GateMode(gate_mode) _INDEX_COLS = [ "cu_num", "token", @@ -1160,7 +1161,7 @@ def get_block_m() -> int: stage2_has_bias=enable_bias and is_flydsl2, ) if ( - not _USE_GENERIC_SWIGLU_MXFP4_LAYOUT + gate_mode != GateMode.SEPARATED and dtype in [dtypes.bf16, dtypes.fp16] and q_type == QuantType.per_1x32 and activation == ActivationType.Swiglu @@ -1449,8 +1450,10 @@ def fused_moe_2stages( topk_ids=None, topk_weights=None, swiglu_limit=0.0, + gate_mode=GateMode.SEPARATED.value, ): quant_func = get_quant(quant_type) + gate_mode = GateMode(gate_mode) token_num, _ = hidden_states.shape E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) dtype = moe_out.dtype @@ -1472,6 +1475,7 @@ def fused_moe_2stages( hidden_pad, intermediate_pad, is_shuffled, + gate_mode, ) if ( quant_type == QuantType.per_1x32 diff --git a/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py b/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py index 83371d35b0..2fd16c410c 100644 --- a/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py +++ b/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py @@ -2139,6 +2139,11 @@ def write_row_to_lds( if const_expr(_apply_weight): v = v * tw if const_expr(_need_quant): + if const_expr(_need_fp4 and act == "swiglu"): + # Match the non-fused Swiglu path: stage1 + # materializes bf16 before MXFP4 quantization. + v_bf16 = arith.trunc_f(T.bf16, v) + v = v_bf16.extf(T.f32) lds_idx = row_base_lds + col_local vec1_f32 = T.vec(1, f32) v1 = vector.from_elements(vec1_f32, [v]) diff --git a/aiter/ops/flydsl/kernels/silu_and_mul_fq.py b/aiter/ops/flydsl/kernels/silu_and_mul_fq.py index 41e62c7848..e4ab3b7ff0 100644 --- a/aiter/ops/flydsl/kernels/silu_and_mul_fq.py +++ b/aiter/ops/flydsl/kernels/silu_and_mul_fq.py @@ -298,13 +298,18 @@ def _f32_to_e2m1(qx_f32): gate = g linear = u - t = g * neg_log2e + t = gate * neg_log2e - if const_expr(swiglu_limit != 0 and act != "swiglu"): + if const_expr(act == "swiglu"): gate = arith.minimumf(gate, _limit) linear = arith.minimumf(linear, _limit) linear = arith.maximumf(linear, _neg_limit) t = gate * swiglu_neg_alpha_log2e + elif const_expr(swiglu_limit != 0): + gate = arith.minimumf(gate, _limit) + linear = arith.minimumf(linear, _limit) + linear = arith.maximumf(linear, _neg_limit) + t = gate * neg_log2e emu = llvm.call_intrinsic( f32, "llvm.amdgcn.exp2.f32", [t], [], [] @@ -314,9 +319,14 @@ def _f32_to_e2m1(qx_f32): f32, "llvm.amdgcn.rcp.f32", [den], [], [] ) if const_expr(act == "swiglu"): - act_vals.append(gate * sig * (linear + c1_f32)) + act_v = gate * sig * (linear + c1_f32) else: - act_vals.append(gate * sig * linear) + act_v = gate * sig * linear + if const_expr(_need_fp4 and act == "swiglu"): + # Keep fused quant numerically aligned with the + # non-fused path, which stores stage1 as bf16 first. + act_v = arith.trunc_f(T.bf16, act_v).extf(T.f32) + act_vals.append(act_v) if const_expr(_need_quant): local_max = c0_f32 diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index eb77cdbfdb..105b168ab0 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -594,6 +594,7 @@ def _iter_csv_cases(): expected_aq_dtype = _runtime_swiglu_mxfp4_q_dtype_a( kwargs["token"], kwargs["actType"], + kwargs["gateMode"], kwargs["qType"], kwargs["AQDType"], kwargs["WQDType"], @@ -634,7 +635,9 @@ def _effective_swiglu_limit(quant_type, aq_dtype, wq_dtype, swiglu_limit): return 0.0 -def _runtime_swiglu_mxfp4_q_dtype_a(token, act_type, q_type, aq_dtype, wq_dtype): +def _runtime_swiglu_mxfp4_q_dtype_a( + token, act_type, gate_mode, q_type, aq_dtype, wq_dtype +): """Return the q_dtype_a that fused_moe will select for Swiglu MXFP4.""" if act_type != aiter.ActivationType.Swiglu: return None @@ -643,11 +646,12 @@ def _runtime_swiglu_mxfp4_q_dtype_a(token, act_type, q_type, aq_dtype, wq_dtype) if aq_dtype not in [dtypes.bf16, dtypes.fp16, dtypes.fp8, dtypes.fp4x2]: return None - if os.environ.get("GPTOSS_USE_GENERIC_SWIGLU_MXFP4_LAYOUT", "0") == "1": + gate_mode = GateMode(gate_mode) + if gate_mode == GateMode.SEPARATED: bound = int(os.environ.get("GPTOSS_SWIGLU_MXFP4_BF16_BOUND", "256")) return dtypes.bf16 if token < bound else dtypes.fp4x2 - bound = int(os.environ.get("AITER_BF16_FP8_BOUND", "512")) + bound = int(os.environ.get("AITER_BF16_FP8_MOE_BOUND", "256")) return dtypes.bf16 if get_gfx() != "gfx950" or token < bound else dtypes.fp8 From ad01c66f48e4d6c0497feec5ef3d442cd06fb3e6 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Mon, 11 May 2026 05:09:25 -0500 Subject: [PATCH 2/2] [MoE] keep Swiglu MXFP4 fused quant in fp32 --- .../flydsl/kernels/mixed_moe_gemm_2stage.py | 5 --- aiter/ops/flydsl/kernels/silu_and_mul_fq.py | 10 +---- op_tests/test_moe_2stage.py | 40 ++++++++++++++++++- 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py b/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py index 2fd16c410c..83371d35b0 100644 --- a/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py +++ b/aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py @@ -2139,11 +2139,6 @@ def write_row_to_lds( if const_expr(_apply_weight): v = v * tw if const_expr(_need_quant): - if const_expr(_need_fp4 and act == "swiglu"): - # Match the non-fused Swiglu path: stage1 - # materializes bf16 before MXFP4 quantization. - v_bf16 = arith.trunc_f(T.bf16, v) - v = v_bf16.extf(T.f32) lds_idx = row_base_lds + col_local vec1_f32 = T.vec(1, f32) v1 = vector.from_elements(vec1_f32, [v]) diff --git a/aiter/ops/flydsl/kernels/silu_and_mul_fq.py b/aiter/ops/flydsl/kernels/silu_and_mul_fq.py index e4ab3b7ff0..eb4dcca312 100644 --- a/aiter/ops/flydsl/kernels/silu_and_mul_fq.py +++ b/aiter/ops/flydsl/kernels/silu_and_mul_fq.py @@ -297,19 +297,17 @@ def _f32_to_e2m1(qx_f32): ) gate = g linear = u - t = gate * neg_log2e - if const_expr(act == "swiglu"): gate = arith.minimumf(gate, _limit) linear = arith.minimumf(linear, _limit) linear = arith.maximumf(linear, _neg_limit) t = gate * swiglu_neg_alpha_log2e - elif const_expr(swiglu_limit != 0): + elif const_expr(swiglu_limit != 0 and act != "swiglu"): gate = arith.minimumf(gate, _limit) linear = arith.minimumf(linear, _limit) linear = arith.maximumf(linear, _neg_limit) - t = gate * neg_log2e + t = gate * swiglu_neg_alpha_log2e emu = llvm.call_intrinsic( f32, "llvm.amdgcn.exp2.f32", [t], [], [] @@ -322,10 +320,6 @@ def _f32_to_e2m1(qx_f32): act_v = gate * sig * (linear + c1_f32) else: act_v = gate * sig * linear - if const_expr(_need_fp4 and act == "swiglu"): - # Keep fused quant numerically aligned with the - # non-fused path, which stores stage1 as bf16 first. - act_v = arith.trunc_f(T.bf16, act_v).extf(T.f32) act_vals.append(act_v) if const_expr(_need_quant): diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index 105b168ab0..c47b42ce81 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -22,6 +22,8 @@ from aiter.fused_moe import ( fused_topk, fused_moe, + get_2stage_cfgs, + get_padded_M, torch_moe_stage1, torch_moe_stage2, ) @@ -244,13 +246,47 @@ def weight_per_128x128_quant(weight, quant_dtype): w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) # # ######################## stage 1 start ########### + stage1_ref_dtype = dtype + if ( + actType == aiter.ActivationType.Swiglu + and qType == aiter.QuantType.per_1x32 + and WQDType == dtypes.fp4x2 + ): + runtime_aq_dtype = _runtime_swiglu_mxfp4_q_dtype_a( + token, actType, gateMode, qType, AQDType, WQDType + ) + if runtime_aq_dtype == dtypes.fp4x2: + metadata = get_2stage_cfgs( + get_padded_M(token), + model_dim, + inter_dim, + E, + topk, + dtype, + runtime_aq_dtype, + WQDType, + qType, + w1.shape[1] == (inter_dim * 2), + actType, + doweight_stage1, + hidden_pad, + intermediate_pad, + getattr(w1_qt_aiter, "is_shuffled", False) + or getattr(w2_qt_aiter, "is_shuffled", False), + gateMode, + ) + if metadata.fuse_quant == "fp4": + # Fused Swiglu MXFP4 quantizes the f32 activation directly. + # Keep the torch reference at f32 until the quantization step. + stage1_ref_dtype = dtypes.fp32 + out1_ref = torch_moe_stage1( a1_qt, w1_qt, w2_qt, topk_weights, topk_ids, - dtype=dtype, + dtype=stage1_ref_dtype, activation=actType, quant_type=qType, a1_scale=a1_scale, @@ -544,6 +580,8 @@ def _row_to_kwargs(row): aq_dtype = _str2dtype(row["q_dtype_a"]) wq_dtype = _str2dtype(row["q_dtype_w"]) act_type = _str2enum(row["act_type"], aiter.ActivationType) + # Tuned CSV rows do not carry gate mode explicitly. Infer the runtime mode + # from the selected activation/weight dtype layout used by fused_moe. gate_mode = _effective_gate_mode(aq_dtype, wq_dtype) return dict( dtype=_str2dtype(row["dtype"]),