Skip to content

Commit 7ac1c41

Browse files
committed
Rebase main branch and fix some issues
Signed-off-by: Fred Wei <20514172+WeiHaocheng@users.noreply.github.com>
1 parent 6f6aaf8 commit 7ac1c41

15 files changed

Lines changed: 422 additions & 218 deletions

File tree

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ class MoeGemmRunner
318318
ActivationType activation_type, int gemm_n, int gemm_k) const;
319319
[[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const;
320320

321-
size_t getMaxWorkspaceSize(int num_experts) const;
321+
size_t getMaxWorkspaceSize(int num_experts, bool use_mxfp8_weight_scaling = false) const;
322322

323323
[[nodiscard]] int getSM() const;
324324

@@ -336,7 +336,7 @@ class MoeGemmRunner
336336
int multi_processor_count_{};
337337
mutable int num_experts_ = 0;
338338
mutable size_t gemm_workspace_size_ = 0;
339-
size_t calcMaxWorkspaceSize(int num_experts) const;
339+
size_t calcMaxWorkspaceSize(int num_experts, bool use_mxfp8_weight_scaling) const;
340340
};
341341

342342
} // namespace kernels::cutlass_kernels

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,22 @@ struct QuantParams
301301
GemmInputs fc2;
302302
} mxfp8_mxfp4;
303303

304+
// MXFP8 x MXFP8 quantization params (W8A8 with UE8M0 1x32 block scales on both
305+
// sides). No per-tensor / global alpha (block scales determine output magnitude).
306+
// Kept as a separate slot from mxfp8_mxfp4 so consumers can disambiguate
307+
// B element bitwidth (4-bit vs 8-bit) without relying on naming aliases.
308+
struct MXFP8MXFP8Inputs
309+
{
310+
struct GemmInputs
311+
{
312+
TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale
313+
= nullptr; // (experts, n, k / 32)
314+
};
315+
316+
GemmInputs fc1;
317+
GemmInputs fc2;
318+
} mxfp8_mxfp8;
319+
304320
// FP4 quantization params
305321
struct FP4Inputs
306322
{
@@ -404,6 +420,36 @@ struct QuantParams
404420
return qp;
405421
}
406422

423+
// MXFP8xMXFP8 grouped MoE: e4m3 acts + e4m3 weights, UE8M0 1x32 block
424+
// scales on both sides. No per-tensor / global alpha is required (block
425+
// scales determine output magnitude). Writes to its own dedicated slot
426+
// so consumers can distinguish from MXFP8xMXFP4 (which has 4-bit B).
427+
static QuantParams MXFP8MXFP8(TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc1_weight_block_scale,
428+
TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale)
429+
{
430+
QuantParams qp;
431+
qp.mxfp8_mxfp8.fc1 = {fc1_weight_block_scale};
432+
qp.mxfp8_mxfp8.fc2 = {fc2_weight_block_scale};
433+
return qp;
434+
}
435+
436+
// Helpers: return the active weight block-scale pointer for an MXFPX-activation path
437+
// regardless of whether B is fp4 (mxfp8_mxfp4) or fp8 (mxfp8_mxfp8).
438+
// Used by consumers that only care "is the activation path block-scaled
439+
// MXFPX" — they don't need to know B's bitwidth, that's already encoded
440+
// in the kernel template instantiation.
441+
TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* mxfpxActFc1WeightScale() const
442+
{
443+
return mxfp8_mxfp8.fc1.weight_block_scale ? mxfp8_mxfp8.fc1.weight_block_scale
444+
: mxfp8_mxfp4.fc1.weight_block_scale;
445+
}
446+
447+
TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* mxfpxActFc2WeightScale() const
448+
{
449+
return mxfp8_mxfp8.fc2.weight_block_scale ? mxfp8_mxfp8.fc2.weight_block_scale
450+
: mxfp8_mxfp4.fc2.weight_block_scale;
451+
}
452+
407453
static QuantParams FP4(float const* fc1_act_global_scale,
408454
TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc1_weight_block_scale,
409455
float const* fc1_global_scale, //
@@ -747,7 +793,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
747793

748794
virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override
749795
{
750-
return moe_gemm_runner_.getMaxWorkspaceSize(num_experts_per_node);
796+
return moe_gemm_runner_.getMaxWorkspaceSize(num_experts_per_node, use_mxfp8_weight_scaling_);
751797
}
752798

753799
std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -769,7 +815,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
769815
reinterpret_cast<ScaleBiasType const*>(bias1), reinterpret_cast<ScaleBiasType const*>(bias2),
770816
reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
771817
reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), router_scales, permuted_row_to_unpermuted_row,
772-
stream);
818+
Self::getScalingType(use_mxfp8_weight_scaling_), stream);
773819
}
774820

775821
std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -802,21 +848,17 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
802848
bool min_latency_mode, MoeMinLatencyParams& min_latency_params, bool use_lora, int start_expert,
803849
MOEParallelismConfig parallelism_config, cudaStream_t stream);
804850

805-
// Non-static so it can read use_mxfp8_weight_scaling_ via getScalingType()
806-
// when picking the runtime FpXBlockScalingType for the <e4m3, e4m3>
807-
// template. The only caller (in moe_kernels.cu) is already inside an
808-
// instance method, so this is a no-op for existing call sites.
809-
std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput> computeStridesTmaWarpSpecialized(
810-
int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput layout_info1,
811-
TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t expanded_num_tokens,
812-
int64_t gemm1_n, int64_t gemm1_k, int64_t gemm2_n, int64_t gemm2_k, int const num_experts_per_node,
813-
T const* gemm1_in, T const* gemm2_in, WeightType const* weights1, WeightType const* weights2,
814-
float const* alpha_scale_flat1, float const* alpha_scale_flat2,
815-
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
851+
static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
852+
computeStridesTmaWarpSpecialized(int64_t const* expert_first_token_offset,
853+
TmaWarpSpecializedGroupedGemmInput layout_info1, TmaWarpSpecializedGroupedGemmInput layout_info2,
854+
int64_t num_tokens, int64_t expanded_num_tokens, int64_t gemm1_n, int64_t gemm1_k, int64_t gemm2_n,
855+
int64_t gemm2_k, int const num_experts_per_node, T const* gemm1_in, T const* gemm2_in,
856+
WeightType const* weights1, WeightType const* weights2, float const* alpha_scale_flat1,
857+
float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
816858
TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
817859
ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output,
818860
UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row,
819-
cudaStream_t stream);
861+
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type, cudaStream_t stream);
820862
static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
821863
computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1,
822864
TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k,
@@ -866,11 +908,10 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
866908
return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_wfp4a16;
867909
}
868910

869-
// TODO: This should eventually take the quant params to give more flexibility
870-
// Instance method (not static) because the <e4m3, e4m3> template
871-
// instantiation serves BOTH per-tensor FP8 and MXFP8xMXFP8 paths, with
872-
// selection driven by the runtime flag use_mxfp8_weight_scaling_.
873-
auto getScalingType() const
911+
// TODO: This should eventually take the full quant params to give more
912+
// flexibility. For now the only runtime selector is the MXFP8 flag for
913+
// the <e4m3, e4m3> instantiation (per-tensor FP8 vs MXFP8 block-scaled).
914+
static auto getScalingType(bool use_mxfp8_weight_scaling)
874915
{
875916
if constexpr (use_wfp4afp8)
876917
{
@@ -883,8 +924,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
883924
else if constexpr (use_fp8 && std::is_same_v<T, WeightType>)
884925
{
885926
// <e4m3, e4m3>: per-tensor FP8 (NONE) or MXFP8 block-scaled (MXFPX).
886-
return use_mxfp8_weight_scaling_ ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX
887-
: TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
927+
return use_mxfp8_weight_scaling ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX
928+
: TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
888929
}
889930
else
890931
{
@@ -990,7 +1031,7 @@ struct GemmProfilerBackend
9901031
nvinfer1::DataType wtype, nvinfer1::DataType otype, int num_experts, int k, int64_t hidden_size,
9911032
int64_t unpadded_hidden_size, int64_t inter_size, int64_t group_size, ActivationType activation_type, bool bias,
9921033
bool use_lora, bool min_latency_mode, bool need_weights, MOEParallelismConfig parallelism_config,
993-
bool const enable_alltoall)
1034+
bool const enable_alltoall, bool use_mxfp8_weight_scaling = false)
9941035
{
9951036
mInterface = &runner;
9961037
mGemmToProfile = gemm_to_profile;
@@ -1011,6 +1052,7 @@ struct GemmProfilerBackend
10111052
mNeedWeights = need_weights;
10121053
mParallelismConfig = parallelism_config;
10131054
mEnableAlltoall = enable_alltoall;
1055+
mUseMxfp8WeightScaling = use_mxfp8_weight_scaling;
10141056
mSM = common::getSMVersion();
10151057

10161058
mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
@@ -1019,6 +1061,13 @@ struct GemmProfilerBackend
10191061
{
10201062
mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
10211063
}
1064+
else if (dtype == nvinfer1::DataType::kFP8 && wtype == nvinfer1::DataType::kFP8 && use_mxfp8_weight_scaling)
1065+
{
1066+
// MXFP8 W8A8: e4m3 acts × e4m3 weights with UE8M0 1x32 block scales on both sides.
1067+
// Profiler must produce MXFPX block-scaled inputs (otherwise the per-expert SF
1068+
// pointer arrays stay uninitialized and the kernel reads garbage SF addresses).
1069+
mScalingType = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX;
1070+
}
10221071
else if ((dtype == nvinfer1::DataType::kFP4 || dtype == nvinfer1::DataType::kINT64)
10231072
&& (wtype == nvinfer1::DataType::kFP4 || wtype == nvinfer1::DataType::kINT64))
10241073
{
@@ -1048,6 +1097,7 @@ struct GemmProfilerBackend
10481097
ActivationType mActivationType{};
10491098
MOEParallelismConfig mParallelismConfig{};
10501099
bool mEnableAlltoall = false;
1100+
bool mUseMxfp8WeightScaling = false;
10511101

10521102
int mSampleIndex = 0;
10531103

0 commit comments

Comments
 (0)