@@ -770,56 +770,38 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
770770 bool const use_mxfp8 = is_wfp8afp8
771771 && hopper_inputs.fpX_block_scaling_type
772772 == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX ;
773+ // Pick the IsMXFPX template parameter for a given FUSION, factoring out the duplicated
774+ // is_wfp4afp8 / is_wfp8afp8 / else chain. C++17-compatible via an integral_constant tag.
775+ auto select_mxfpx_mode = [&](auto fusion_tag)
776+ {
777+ constexpr auto FUSION = decltype (fusion_tag)::value;
778+ if constexpr (is_wfp4afp8)
779+ {
780+ return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
781+ OutputType, EpilogueTag, FUSION , true >;
782+ }
783+ else if constexpr (is_wfp8afp8)
784+ {
785+ return use_mxfp8 ? &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
786+ WeightType, OutputType, EpilogueTag, FUSION , true >
787+ : &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
788+ WeightType, OutputType, EpilogueTag, FUSION , false >;
789+ }
790+ else
791+ {
792+ return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
793+ OutputType, EpilogueTag, FUSION , false >;
794+ }
795+ };
773796 auto select_function = [&]()
774797 {
798+ using Fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
775799 switch (hopper_inputs.fusion )
776800 {
777- case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE :
778- if constexpr (is_wfp4afp8)
779- {
780- return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
781- OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE ,
782- true >;
783- }
784- else if constexpr (is_wfp8afp8)
785- {
786- return use_mxfp8 ? &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
787- WeightType, OutputType, EpilogueTag,
788- TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE , true >
789- : &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
790- WeightType, OutputType, EpilogueTag,
791- TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE , false >;
792- }
793- else
794- {
795- return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
796- OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE ,
797- false >;
798- }
799- case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE :
800- if constexpr (is_wfp4afp8)
801- {
802- return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
803- OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE ,
804- true >;
805- }
806- else if constexpr (is_wfp8afp8)
807- {
808- return use_mxfp8 ? &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
809- WeightType, OutputType, EpilogueTag,
810- TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE , true >
811- : &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
812- WeightType, OutputType, EpilogueTag,
813- TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE , false >;
814- }
815- else
816- {
817- return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
818- OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE ,
819- false >;
820- }
821- case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::ACTIVATION :
822- case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION :
801+ case Fusion::FINALIZE : return select_mxfpx_mode (std::integral_constant<Fusion, Fusion::FINALIZE >{});
802+ case Fusion::NONE : return select_mxfpx_mode (std::integral_constant<Fusion, Fusion::NONE >{});
803+ case Fusion::ACTIVATION :
804+ case Fusion::GATED_ACTIVATION :
823805 default : TLLM_THROW (" Unimplemented fusion %d requested" , (int ) hopper_inputs.fusion );
824806 };
825807 };
@@ -923,10 +905,13 @@ template <typename T, typename WeightType, typename OutputType, typename ScaleBi
923905size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getMaxWorkspaceSize(
924906 int num_experts, bool use_mxfp8_weight_scaling) const
925907{
926- if (num_experts != num_experts_)
908+ if (num_experts != num_experts_ || use_mxfp8_weight_scaling != use_mxfp8_weight_scaling_ )
927909 {
928- TLLM_LOG_TRACE (" Calling getMaxWorkspaceSize() with a new expert count %d vs %d" , num_experts, num_experts_);
910+ TLLM_LOG_TRACE (
911+ " Calling getMaxWorkspaceSize() with a new (expert count, use_mxfp8_weight_scaling) (%d, %d) vs (%d, %d)" ,
912+ num_experts, (int ) use_mxfp8_weight_scaling, num_experts_, (int ) use_mxfp8_weight_scaling_);
929913 num_experts_ = num_experts;
914+ use_mxfp8_weight_scaling_ = use_mxfp8_weight_scaling;
930915 gemm_workspace_size_ = calcMaxWorkspaceSize (num_experts, use_mxfp8_weight_scaling);
931916 }
932917 return gemm_workspace_size_;
0 commit comments