@@ -301,6 +301,22 @@ struct QuantParams
301301 GemmInputs fc2;
302302 } mxfp8_mxfp4;
303303
304+ // MXFP8 x MXFP8 quantization params (W8A8 with UE8M0 1x32 block scales on both
305+ // sides). No per-tensor / global alpha (block scales determine output magnitude).
306+ // Kept as a separate slot from mxfp8_mxfp4 so consumers can disambiguate
307+ // B element bitwidth (4-bit vs 8-bit) without relying on naming aliases.
308+ struct MXFP8MXFP8Inputs
309+ {
310+ struct GemmInputs
311+ {
312+ TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const * weight_block_scale
313+ = nullptr ; // (experts, n, k / 32)
314+ };
315+
316+ GemmInputs fc1;
317+ GemmInputs fc2;
318+ } mxfp8_mxfp8;
319+
304320 // FP4 quantization params
305321 struct FP4Inputs
306322 {
@@ -404,6 +420,36 @@ struct QuantParams
404420 return qp;
405421 }
406422
423+ // MXFP8xMXFP8 grouped MoE: e4m3 acts + e4m3 weights, UE8M0 1x32 block
424+ // scales on both sides. No per-tensor / global alpha is required (block
425+ // scales determine output magnitude). Writes to its own dedicated slot
426+ // so consumers can distinguish from MXFP8xMXFP4 (which has 4-bit B).
427+ static QuantParams MXFP8MXFP8 (TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const * fc1_weight_block_scale,
428+ TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const * fc2_weight_block_scale)
429+ {
430+ QuantParams qp;
431+ qp.mxfp8_mxfp8 .fc1 = {fc1_weight_block_scale};
432+ qp.mxfp8_mxfp8 .fc2 = {fc2_weight_block_scale};
433+ return qp;
434+ }
435+
436+ // Helpers: return the active MXFPX activation-side block-scale pointer
437+ // regardless of whether B is fp4 (mxfp8_mxfp4) or fp8 (mxfp8_mxfp8).
438+ // Used by consumers that only care "is the activation path block-scaled
439+ // MXFPX" — they don't need to know B's bitwidth, that's already encoded
440+ // in the kernel template instantiation.
441+ TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const * mxfpxActFc1WeightScale () const
442+ {
443+ return mxfp8_mxfp8.fc1 .weight_block_scale ? mxfp8_mxfp8.fc1 .weight_block_scale
444+ : mxfp8_mxfp4.fc1 .weight_block_scale ;
445+ }
446+
447+ TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const * mxfpxActFc2WeightScale () const
448+ {
449+ return mxfp8_mxfp8.fc2 .weight_block_scale ? mxfp8_mxfp8.fc2 .weight_block_scale
450+ : mxfp8_mxfp4.fc2 .weight_block_scale ;
451+ }
452+
407453 static QuantParams FP4 (float const * fc1_act_global_scale,
408454 TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const * fc1_weight_block_scale,
409455 float const * fc1_global_scale, //
@@ -747,7 +793,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
747793
748794 virtual size_t getGemmWorkspaceSize (int num_experts_per_node) const override
749795 {
750- return moe_gemm_runner_.getMaxWorkspaceSize (num_experts_per_node);
796+ return moe_gemm_runner_.getMaxWorkspaceSize (num_experts_per_node, use_mxfp8_weight_scaling_ );
751797 }
752798
753799 std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -769,7 +815,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
769815 reinterpret_cast <ScaleBiasType const *>(bias1), reinterpret_cast <ScaleBiasType const *>(bias2),
770816 reinterpret_cast <UnfusedGemmOutputType*>(gemm1_output),
771817 reinterpret_cast <UnfusedGemmOutputType*>(gemm2_output), router_scales, permuted_row_to_unpermuted_row,
772- stream);
818+ Self::getScalingType (use_mxfp8_weight_scaling_), stream);
773819 }
774820
775821 std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -802,21 +848,17 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
802848 bool min_latency_mode, MoeMinLatencyParams& min_latency_params, bool use_lora, int start_expert,
803849 MOEParallelismConfig parallelism_config, cudaStream_t stream);
804850
805- // Non-static so it can read use_mxfp8_weight_scaling_ via getScalingType()
806- // when picking the runtime FpXBlockScalingType for the <e4m3, e4m3>
807- // template. The only caller (in moe_kernels.cu) is already inside an
808- // instance method, so this is a no-op for existing call sites.
809- std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> computeStridesTmaWarpSpecialized (
810- int64_t const * expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput layout_info1,
811- TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t expanded_num_tokens,
812- int64_t gemm1_n, int64_t gemm1_k, int64_t gemm2_n, int64_t gemm2_k, int const num_experts_per_node,
813- T const * gemm1_in, T const * gemm2_in, WeightType const * weights1, WeightType const * weights2,
814- float const * alpha_scale_flat1, float const * alpha_scale_flat2,
815- TmaWarpSpecializedGroupedGemmInput::ElementSF const * fp4_act_flat1,
851+ static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
852+ computeStridesTmaWarpSpecialized (int64_t const * expert_first_token_offset,
853+ TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2,
854+ int64_t num_tokens, int64_t expanded_num_tokens, int64_t gemm1_n, int64_t gemm1_k, int64_t gemm2_n,
855+ int64_t gemm2_k, int const num_experts_per_node, T const * gemm1_in, T const * gemm2_in,
856+ WeightType const * weights1, WeightType const * weights2, float const * alpha_scale_flat1,
857+ float const * alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const * fp4_act_flat1,
816858 TmaWarpSpecializedGroupedGemmInput::ElementSF const * fp4_act_flat2, QuantParams quant_params,
817859 ScaleBiasType const * bias1, ScaleBiasType const * bias2, UnfusedGemmOutputType* gemm1_output,
818860 UnfusedGemmOutputType* gemm2_output, float const * router_scales, int const * permuted_row_to_unpermuted_row,
819- cudaStream_t stream);
861+ TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type, cudaStream_t stream);
820862 static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
821863 computeStridesTmaWarpSpecializedLowLatency (TmaWarpSpecializedGroupedGemmInput layout_info1,
822864 TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k,
@@ -866,11 +908,10 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
866908 return RunnerType::supportsTmaWarpSpecialized (sm) && sm >= 90 && !use_wfp4a16;
867909 }
868910
869- // TODO: This should eventually take the quant params to give more flexibility
870- // Instance method (not static) because the <e4m3, e4m3> template
871- // instantiation serves BOTH per-tensor FP8 and MXFP8xMXFP8 paths, with
872- // selection driven by the runtime flag use_mxfp8_weight_scaling_.
873- auto getScalingType () const
911+ // TODO: This should eventually take the full quant params to give more
912+ // flexibility. For now the only runtime selector is the MXFP8 flag for
913+ // the <e4m3, e4m3> instantiation (per-tensor FP8 vs MXFP8 block-scaled).
914+ static auto getScalingType (bool use_mxfp8_weight_scaling)
874915 {
875916 if constexpr (use_wfp4afp8)
876917 {
@@ -883,8 +924,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
883924 else if constexpr (use_fp8 && std::is_same_v<T, WeightType>)
884925 {
885926 // <e4m3, e4m3>: per-tensor FP8 (NONE) or MXFP8 block-scaled (MXFPX).
886- return use_mxfp8_weight_scaling_ ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX
887- : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ;
927+ return use_mxfp8_weight_scaling ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX
928+ : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ;
888929 }
889930 else
890931 {
@@ -990,7 +1031,7 @@ struct GemmProfilerBackend
9901031 nvinfer1::DataType wtype, nvinfer1::DataType otype, int num_experts, int k, int64_t hidden_size,
9911032 int64_t unpadded_hidden_size, int64_t inter_size, int64_t group_size, ActivationType activation_type, bool bias,
9921033 bool use_lora, bool min_latency_mode, bool need_weights, MOEParallelismConfig parallelism_config,
993- bool const enable_alltoall)
1034+ bool const enable_alltoall, bool use_mxfp8_weight_scaling = false )
9941035 {
9951036 mInterface = &runner;
9961037 mGemmToProfile = gemm_to_profile;
@@ -1011,6 +1052,7 @@ struct GemmProfilerBackend
10111052 mNeedWeights = need_weights;
10121053 mParallelismConfig = parallelism_config;
10131054 mEnableAlltoall = enable_alltoall;
1055+ mUseMxfp8WeightScaling = use_mxfp8_weight_scaling;
10141056 mSM = common::getSMVersion ();
10151057
10161058 mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE ;
@@ -1019,6 +1061,13 @@ struct GemmProfilerBackend
10191061 {
10201062 mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX ;
10211063 }
1064+ else if (dtype == nvinfer1::DataType::kFP8 && wtype == nvinfer1::DataType::kFP8 && use_mxfp8_weight_scaling)
1065+ {
1066+ // MXFP8 W8A8: e4m3 acts × e4m3 weights with UE8M0 1x32 block scales on both sides.
1067+ // Profiler must produce MXFPX block-scaled inputs (otherwise the per-expert SF
1068+ // pointer arrays stay uninitialized and the kernel reads garbage SF addresses).
1069+ mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX ;
1070+ }
10221071 else if ((dtype == nvinfer1::DataType::kFP4 || dtype == nvinfer1::DataType::kINT64 )
10231072 && (wtype == nvinfer1::DataType::kFP4 || wtype == nvinfer1::DataType::kINT64 ))
10241073 {
@@ -1048,6 +1097,7 @@ struct GemmProfilerBackend
10481097 ActivationType mActivationType {};
10491098 MOEParallelismConfig mParallelismConfig {};
10501099 bool mEnableAlltoall = false ;
1100+ bool mUseMxfp8WeightScaling = false ;
10511101
10521102 int mSampleIndex = 0 ;
10531103
0 commit comments