@@ -483,17 +483,17 @@ namespace kernels::cutlass_kernels
483483
// Instance overload of getConfigs: forwards to the overload that takes an explicit SM value,
// supplying this runner's sm_ member. The patched (+) lines thread the new use_mxfp8 argument
// through unchanged alongside supports_finalize_fusion.
484484template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
485485std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(
486- bool supports_finalize_fusion) const
486+ bool supports_finalize_fusion, bool use_mxfp8 ) const
487487{
488- return getConfigs (sm_, supports_finalize_fusion);
488+ return getConfigs (sm_, supports_finalize_fusion, use_mxfp8 );
489489}
490490
491491template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
492492std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(
493- int sm, bool supports_finalize_fusion)
493+ int sm, bool supports_finalize_fusion, bool use_mxfp8 )
494494{
495495 std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs
496- = getTmaWarpSpecializedConfigs (sm, supports_finalize_fusion);
496+ = getTmaWarpSpecializedConfigs (sm, supports_finalize_fusion, use_mxfp8 );
497497 std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs = getAmpereConfigs (sm);
498498 std::copy (ampere_configs.begin (), ampere_configs.end (), std::back_inserter (candidate_configs));
499499 return candidate_configs;
@@ -530,7 +530,7 @@ MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getAmpereConfigs(int sm
530530template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
531531std::vector<cutlass_extensions::CutlassGemmConfig>
532532MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getTmaWarpSpecializedConfigs(
533- int sm, bool supports_finalize_fusion)
533+ int sm, bool supports_finalize_fusion, bool use_mxfp8 )
534534{
535535 using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
536536 static constexpr auto weight_only_flag
@@ -545,8 +545,16 @@ MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getTmaWarpSpecializedCo
545545 static constexpr auto fp4_only_flag
546546 = (use_fp4 || use_wfp4afp8) ? CutlassGemmConfig::FP4_ONLY : CutlassGemmConfig::NONE ;
547547 static constexpr auto fp8fp4_mixed_flag = use_wfp4afp8 ? CutlassGemmConfig::FP8FP4_MIXED : CutlassGemmConfig::NONE ;
548- auto config_type_param = static_cast <CutlassGemmConfig::CandidateConfigTypeParam>(weight_only_flag | simt_only_flag
549- | grouped_gemm_flag | enable_blackwell | enable_hopper | fp8_only_flag | fp4_only_flag | fp8fp4_mixed_flag);
548+ // MXFP8xMXFP8 only applies to <e4m3, e4m3>; for other type pairs the flag is ignored.
549+ #if defined(ENABLE_FP8)
550+ static constexpr bool is_wfp8afp8 = std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp8_e4m3>;
551+ #else
552+ static constexpr bool is_wfp8afp8 = false ;
553+ #endif
554+ int const mxfp8_flag = (use_mxfp8 && is_wfp8afp8) ? CutlassGemmConfig::MXFP8_MXFP8 : CutlassGemmConfig::NONE ;
555+ auto config_type_param
556+ = static_cast <CutlassGemmConfig::CandidateConfigTypeParam>(weight_only_flag | simt_only_flag | grouped_gemm_flag
557+ | enable_blackwell | enable_hopper | fp8_only_flag | fp4_only_flag | fp8fp4_mixed_flag | mxfp8_flag);
550558 TLLM_CHECK_WITH_INFO (!(enable_blackwell && enable_hopper), " Blackwell and hopper flags are mutually exclusive" );
551559
552560 sm = use_wfp4afp8 && sm == 103 ? 100 : sm;
@@ -770,56 +778,38 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
770778 bool const use_mxfp8 = is_wfp8afp8
771779 && hopper_inputs.fpX_block_scaling_type
772780 == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX ;
781+ // Returns the address of the dispatchMoeGemmSelectTileShapeTmaWarpSpecialized instantiation for the fusion
782+ // carried by `fusion_tag`; the tag is an integral_constant so FUSION can be a template argument under C++17.
// The trailing bool template argument (IsMXFPX) is: always true when is_wfp4afp8; selected at run time
// from use_mxfp8 when is_wfp8afp8; always false for every other type pair.
783+ auto select_mxfpx_mode = [&](auto fusion_tag)
784+ {
785+ constexpr auto FUSION = decltype (fusion_tag)::value;
786+ if constexpr (is_wfp4afp8)
787+ {
788+ return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
789+ OutputType, EpilogueTag, FUSION , true >;
790+ }
791+ else if constexpr (is_wfp8afp8)
792+ {
// Only this branch consults the runtime use_mxfp8 flag captured by reference.
793+ return use_mxfp8 ? &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
794+ WeightType, OutputType, EpilogueTag, FUSION , true >
795+ : &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
796+ WeightType, OutputType, EpilogueTag, FUSION , false >;
797+ }
798+ else
799+ {
800+ return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
801+ OutputType, EpilogueTag, FUSION , false >;
802+ }
803+ };
// Maps the runtime value of hopper_inputs.fusion to the matching dispatch function pointer.
// Pre-image (-) lines: each supported case spelled out the is_wfp4afp8 / is_wfp8afp8 / else chain inline.
// Post-image (+) lines: each supported case passes a compile-time integral_constant tag to select_mxfpx_mode.
// In both versions only FINALIZE and NONE return a pointer; every other value reaches TLLM_THROW.
773804 auto select_function = [&]()
774805 {
806+ using Fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
775807 switch (hopper_inputs.fusion )
776808 {
777- case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE :
778- if constexpr (is_wfp4afp8)
779- {
780- return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
781- OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE ,
782- true >;
783- }
784- else if constexpr (is_wfp8afp8)
785- {
786- return use_mxfp8 ? &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
787- WeightType, OutputType, EpilogueTag,
788- TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE , true >
789- : &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
790- WeightType, OutputType, EpilogueTag,
791- TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE , false >;
792- }
793- else
794- {
795- return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
796- OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE ,
797- false >;
798- }
799- case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE :
800- if constexpr (is_wfp4afp8)
801- {
802- return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
803- OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE ,
804- true >;
805- }
806- else if constexpr (is_wfp8afp8)
807- {
808- return use_mxfp8 ? &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
809- WeightType, OutputType, EpilogueTag,
810- TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE , true >
811- : &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
812- WeightType, OutputType, EpilogueTag,
813- TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE , false >;
814- }
815- else
816- {
817- return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
818- OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE ,
819- false >;
820- }
821- case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::ACTIVATION :
822- case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION :
809+ case Fusion::FINALIZE : return select_mxfpx_mode (std::integral_constant<Fusion, Fusion::FINALIZE >{});
810+ case Fusion::NONE : return select_mxfpx_mode (std::integral_constant<Fusion, Fusion::NONE >{});
// ACTIVATION and GATED_ACTIVATION have no body of their own and fall through to the default case.
811+ case Fusion::ACTIVATION :
812+ case Fusion::GATED_ACTIVATION :
823813 default : TLLM_THROW (" Unimplemented fusion %d requested" , (int ) hopper_inputs.fusion );
824814 };
825815 };
@@ -923,10 +913,13 @@ template <typename T, typename WeightType, typename OutputType, typename ScaleBi
923913size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getMaxWorkspaceSize(
924914 int num_experts, bool use_mxfp8_weight_scaling) const
925915{
926- if (num_experts != num_experts_)
916+ if (num_experts != num_experts_ || use_mxfp8_weight_scaling != use_mxfp8_weight_scaling_ )
927917 {
928- TLLM_LOG_TRACE (" Calling getMaxWorkspaceSize() with a new expert count %d vs %d" , num_experts, num_experts_);
918+ TLLM_LOG_TRACE (
919+ " Calling getMaxWorkspaceSize() with a new (expert count, use_mxfp8_weight_scaling) (%d, %d) vs (%d, %d)" ,
920+ num_experts, (int ) use_mxfp8_weight_scaling, num_experts_, (int ) use_mxfp8_weight_scaling_);
929921 num_experts_ = num_experts;
922+ use_mxfp8_weight_scaling_ = use_mxfp8_weight_scaling;
930923 gemm_workspace_size_ = calcMaxWorkspaceSize (num_experts, use_mxfp8_weight_scaling);
931924 }
932925 return gemm_workspace_size_;
@@ -949,8 +942,11 @@ size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspace
949942 && !use_w4afp8 && !use_wfp4a16)
950943 {
951944 // Finalize fusion may not actually be supported by the kernel,
952- // if they are not we will catch the error and skip them
953- auto configs = getTmaWarpSpecializedConfigs (sm_, true );
945+ // if it is not, we will catch the error and skip those configs. Restrict the
946+ // candidate set to MXFP8-valid tiles when the caller is sizing for the
947+ // MXFP8xMXFP8 variant; otherwise the FP8 list would include tiles the
948+ // dispatcher rejects.
949+ auto configs = getTmaWarpSpecializedConfigs (sm_, true , use_mxfp8_weight_scaling);
954950 // For <e4m3, e4m3> the same template compiles both per-tensor FP8
955951 // (NONE) and MXFP8 block-scaled (MXFPX) variants; the caller passes
956952 // `use_mxfp8_weight_scaling` so we size workspace for exactly the
0 commit comments