Skip to content

Commit a46a1f0

Browse files
committed
Resolve some new comments
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
1 parent ea7d3e2 commit a46a1f0

4 files changed

Lines changed: 76 additions & 72 deletions

File tree

cpp/tensorrt_llm/kernels/cutlass_kernels/fp4_gemm/mxfp8_mxfp4_gemm_template_sm100.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ struct MXSMTypeAdapter<__2SM>
8484
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100;
8585
};
8686

87+
namespace detail
88+
{
89+
template <typename T, typename = void>
90+
struct has_bias_ptr : std::false_type
91+
{
92+
};
93+
94+
template <typename T>
95+
struct has_bias_ptr<T, std::void_t<decltype(std::declval<T&>().bias_ptr)>> : std::true_type
96+
{
97+
};
98+
} // namespace detail
99+
87100
#ifdef PLACEHOLDER_KERNELS
88101

89102
template <typename T, typename CTA_M, typename CTA_N, typename CTA_K, typename CGA_M, typename CGA_N, typename CGA_K,
@@ -187,7 +200,10 @@ typename Gemm::Arguments prepareGemmArgsSm100(void* D, void const* A, void const
187200
operator_args.mode = cutlass::gemm::GemmUniversalMode::kGemm;
188201
auto& fusion_args = operator_args.epilogue.thread;
189202
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(global_sf);
190-
fusion_args.bias_ptr = static_cast<ElementD const*>(bias);
203+
if constexpr (detail::has_bias_ptr<std::decay_t<decltype(fusion_args)>>::value)
204+
{
205+
fusion_args.bias_ptr = static_cast<ElementD const*>(bias);
206+
}
191207

192208
operator_args.problem_shape = cute::make_shape(m, n, k, batch_count);
193209

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ class MoeGemmRunner
335335
int sm_{};
336336
int multi_processor_count_{};
337337
mutable int num_experts_ = 0;
338+
mutable bool use_mxfp8_weight_scaling_ = false;
338339
mutable size_t gemm_workspace_size_ = 0;
339340
size_t calcMaxWorkspaceSize(int num_experts, bool use_mxfp8_weight_scaling) const;
340341
};

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h

Lines changed: 33 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -770,56 +770,38 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::dispatchToArch(
770770
bool const use_mxfp8 = is_wfp8afp8
771771
&& hopper_inputs.fpX_block_scaling_type
772772
== TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
773+
// Pick the IsMXFPX template parameter for a given FUSION, factoring out the duplicated
774+
// is_wfp4afp8 / is_wfp8afp8 / else chain. C++17-compatible via an integral_constant tag.
775+
auto select_mxfpx_mode = [&](auto fusion_tag)
776+
{
777+
constexpr auto FUSION = decltype(fusion_tag)::value;
778+
if constexpr (is_wfp4afp8)
779+
{
780+
return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
781+
OutputType, EpilogueTag, FUSION, true>;
782+
}
783+
else if constexpr (is_wfp8afp8)
784+
{
785+
return use_mxfp8 ? &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
786+
WeightType, OutputType, EpilogueTag, FUSION, true>
787+
: &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
788+
WeightType, OutputType, EpilogueTag, FUSION, false>;
789+
}
790+
else
791+
{
792+
return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
793+
OutputType, EpilogueTag, FUSION, false>;
794+
}
795+
};
773796
auto select_function = [&]()
774797
{
798+
using Fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion;
775799
switch (hopper_inputs.fusion)
776800
{
777-
case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE:
778-
if constexpr (is_wfp4afp8)
779-
{
780-
return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
781-
OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE,
782-
true>;
783-
}
784-
else if constexpr (is_wfp8afp8)
785-
{
786-
return use_mxfp8 ? &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
787-
WeightType, OutputType, EpilogueTag,
788-
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE, true>
789-
: &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
790-
WeightType, OutputType, EpilogueTag,
791-
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE, false>;
792-
}
793-
else
794-
{
795-
return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
796-
OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE,
797-
false>;
798-
}
799-
case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE:
800-
if constexpr (is_wfp4afp8)
801-
{
802-
return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
803-
OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE,
804-
true>;
805-
}
806-
else if constexpr (is_wfp8afp8)
807-
{
808-
return use_mxfp8 ? &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
809-
WeightType, OutputType, EpilogueTag,
810-
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, true>
811-
: &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T,
812-
WeightType, OutputType, EpilogueTag,
813-
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, false>;
814-
}
815-
else
816-
{
817-
return &cutlass_kernels_oss::dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType,
818-
OutputType, EpilogueTag, TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE,
819-
false>;
820-
}
821-
case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::ACTIVATION:
822-
case TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION:
801+
case Fusion::FINALIZE: return select_mxfpx_mode(std::integral_constant<Fusion, Fusion::FINALIZE>{});
802+
case Fusion::NONE: return select_mxfpx_mode(std::integral_constant<Fusion, Fusion::NONE>{});
803+
case Fusion::ACTIVATION:
804+
case Fusion::GATED_ACTIVATION:
823805
default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_inputs.fusion);
824806
};
825807
};
@@ -923,10 +905,13 @@ template <typename T, typename WeightType, typename OutputType, typename ScaleBi
923905
size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getMaxWorkspaceSize(
924906
int num_experts, bool use_mxfp8_weight_scaling) const
925907
{
926-
if (num_experts != num_experts_)
908+
if (num_experts != num_experts_ || use_mxfp8_weight_scaling != use_mxfp8_weight_scaling_)
927909
{
928-
TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_);
910+
TLLM_LOG_TRACE(
911+
"Calling getMaxWorkspaceSize() with a new (expert count, use_mxfp8_weight_scaling) (%d, %d) vs (%d, %d)",
912+
num_experts, (int) use_mxfp8_weight_scaling, num_experts_, (int) use_mxfp8_weight_scaling_);
929913
num_experts_ = num_experts;
914+
use_mxfp8_weight_scaling_ = use_mxfp8_weight_scaling;
930915
gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts, use_mxfp8_weight_scaling);
931916
}
932917
return gemm_workspace_size_;

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,13 @@ void dispatchMoeGemmSelectTileShapeTmaWarpSpecialized(TmaWarpSpecializedGroupedG
463463

464464
if (gemm_config.sm_version == 90)
465465
{
466-
if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag, FUSION>())
466+
// Block-scaled MXFP8xMXFP8 (IsMXFPX=true) is Blackwell-only; the SM90 launcher
467+
// has no `is_mx_fpx=True` explicit instantiation in generate_kernels.py. Gate
468+
// the SM90 dispatch on `!IsMXFPX` so the IsMXFPX=true template is never
469+
// instantiated for Sm90 (otherwise the link of libth_common.so fails with
470+
// undefined references when SM90 is included in CMAKE_CUDA_ARCHITECTURES).
471+
if constexpr (!IsMXFPX
472+
&& kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType, EpilogueTag, FUSION>())
467473
{
468474
switch (gemm_config.tile_config_sm90)
469475
{
@@ -558,34 +564,30 @@ size_t calcMaxWorkspaceSizeTmaWarpSpecialized(int num_experts, cutlass_extension
558564
// <e4m3, e4m3> needs the IsMXFPX template to match what the runtime dispatch will pick.
559565
constexpr bool is_wfp4afp8 = std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp4_e2m1>;
560566
constexpr bool is_wfp8afp8 = std::is_same_v<T, __nv_fp8_e4m3> && std::is_same_v<WeightType, __nv_fp8_e4m3>;
561-
if constexpr (is_wfp4afp8)
567+
auto pick_kernel = [&]()
562568
{
563-
dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType, OutputType,
564-
cutlass_extensions::EpilogueOpDefault, FUSION, true>(
565-
input, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
566-
}
567-
else if constexpr (is_wfp8afp8)
568-
{
569-
bool const use_mxfp8 = fpX_block_scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
570-
if (use_mxfp8)
569+
if constexpr (is_wfp4afp8)
570+
{
571+
return &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType, OutputType,
572+
cutlass_extensions::EpilogueOpDefault, FUSION, true>;
573+
}
574+
else if constexpr (is_wfp8afp8)
571575
{
572-
dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType, OutputType,
573-
cutlass_extensions::EpilogueOpDefault, FUSION, true>(
574-
input, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
576+
bool const use_mxfp8
577+
= fpX_block_scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
578+
return use_mxfp8 ? &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType, OutputType,
579+
cutlass_extensions::EpilogueOpDefault, FUSION, true>
580+
: &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType, OutputType,
581+
cutlass_extensions::EpilogueOpDefault, FUSION, false>;
575582
}
576583
else
577584
{
578-
dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType, OutputType,
579-
cutlass_extensions::EpilogueOpDefault, FUSION, false>(
580-
input, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
585+
return &dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType, OutputType,
586+
cutlass_extensions::EpilogueOpDefault, FUSION, false>;
581587
}
582-
}
583-
else
584-
{
585-
dispatchMoeGemmSelectTileShapeTmaWarpSpecialized<T, WeightType, OutputType,
586-
cutlass_extensions::EpilogueOpDefault, FUSION, false>(
587-
input, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
588-
}
588+
};
589+
auto selected_kernel = pick_kernel();
590+
selected_kernel(input, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count);
589591
return count;
590592
}
591593

0 commit comments

Comments
 (0)