Skip to content

Commit ae2dc3d

Browse files
authored
[None][feat] Add silu to trtllm-gen MoE (#11663)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
1 parent 7a68c42 commit ae2dc3d

1,103 files changed

Lines changed: 14910 additions & 5859 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ enum class ActType
4242
//
4343
// GatedSilu is a special case of SwiGlu where the alpha is 1.0 and the beta is 0.0.
4444
SwiGlu,
45-
Relu2
45+
Relu2,
46+
Silu
4647
};
4748

4849
// Type of the element-wise activation to apply after the Gemm
@@ -59,6 +60,10 @@ enum class EltwiseActType
5960
// act = relu(x0) ^ 2
6061
// where x0 is the output of the Gemm.
6162
Relu2,
63+
// Silu is defined as the following operation:
64+
// act = x0 * sigmoid(x0)
65+
// where x0 is the output of the Gemm.
66+
Silu
6267
};
6368

6469
struct TrtllmGenBatchedGemmRunnerOptions

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,10 @@ struct BatchedGemmData
141141
// The rightmost dimension is contiguous in memory.
142142
//
143143
// If DeepSeek FP8 recipe is not used, but for MxFp{4,8}, MxInt4 and NvFp4 formats:
144-
// The layout of scaling factors for A is always R128c4
144+
// If the layout is R128c4,
145145
// M must be a multiple of 128.
146-
// K must be a multiple of 64.
147-
// The "logical" shape is: [paddedM, K / P], where P is the scaling block size.
146+
// K must be a multiple of 4 * P, where P is the scaling block size.
147+
// The "logical" shape is: [paddedM, K / P].
148148
// The R128c4 layout is: [paddedM / 128, K / P / 4, 512].
149149
// The shape we use for TMA is: [paddedM / 128, K / P / 4, 2, 256].
150150
// Where paddedM is M if (routeAct == true && batchM), or
@@ -302,7 +302,7 @@ struct BatchedGemmData
302302

303303
// The pre-activation scaling factor (typically dequantA * dequantB) for non-gated non-linear
304304
// activation.
305-
// Only used when non-linear activation is applied (e.g., GELU, Relu2).
305+
// Only used when non-linear activation is applied (e.g., GELU, Relu2, Silu).
306306
// When used, scaleC should be quantScaleC only, and this scale is applied before the
307307
// activation. Shape is [B].
308308
float const* mPtrScaleAct{nullptr};
@@ -786,7 +786,7 @@ class BatchedGemmInterface
786786
{
787787
numCtasBatch += batchM
788788
? gemm::divUp(options.mBatchedM[bi], options.mTileM * options.mClusterDimX) * options.mClusterDimX
789-
: gemm::divUp(options.mBatchedN[bi], options.mTileN);
789+
: gemm::divUp(options.mBatchedN[bi], options.mTileN * options.mClusterDimY) * options.mClusterDimY;
790790
}
791791
}
792792
// For MoE, mNumTokens != 0 and the number of CTAs is known only at runtime.
@@ -923,19 +923,21 @@ class BatchedGemmInterface
923923
{
924924
totalNumPaddedTokens += batchM
925925
? gemm::divUpMul(options.mBatchedM[bi], options.mTileM * options.mClusterDimX)
926-
: gemm::divUpMul(options.mBatchedN[bi], options.mTileN);
926+
: gemm::divUpMul(options.mBatchedN[bi], options.mTileN * options.mClusterDimY);
927927
}
928928
}
929929
else
930930
{
931931
// Get tile in token dim.
932-
auto tileTokensDim = batchM ? options.mTileM * options.mClusterDimX : options.mTileN;
932+
auto tileTokensDim
933+
= batchM ? options.mTileM * options.mClusterDimX : options.mTileN * options.mClusterDimY;
933934
totalNumPaddedTokens = data.mProblemDimensions.mMaxNumCtasInTokenDim * tileTokensDim;
934935
}
935936
// Get options from config.
936937
auto& options = config.mOptions;
937938

938-
int const tokenTile = batchM ? options.mTileM * options.mClusterDimX : options.mTileN;
939+
int const tokenTile
940+
= batchM ? options.mTileM * options.mClusterDimX : options.mTileN * options.mClusterDimY;
939941

940942
auto const numTokens = totalNumPaddedTokens;
941943
auto const intermediateDim = batchM ? options.mN : options.mM;

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,18 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
100100
tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB,
101101
gemm::EltwiseActType eltwiseActType, bool enablesEarlyExit, bool enablesDelayedEarlyExit,
102102
bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN,
103-
bool fuseUtccpWithUtcmma, bool gridTriggerSecondaryA, bool gridTriggerSecondaryB,
104-
bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit,
105-
bool hoistMmaTaskTryWaits, int k, gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA,
106-
gemm::MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n,
107-
int numEpilogueWarps, int numRegsCastAWarps, int numRegsCopySfLdsSttm, int numRegsCopySparsityInfo,
108-
int numRegsPerThreadEpilogueWarp, int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK,
109-
int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile,
110-
int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp,
111-
int32_t sfBlockSizeA, int32_t sfBlockSizeB, int32_t sfBlockSizeC, tg::SfLayout sfLayoutA,
112-
tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK, tg::Sparsity sparsityA,
113-
gemm::SplitK splitK, int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler,
114-
bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8,
103+
int fallbackClusterDimX, int fallbackClusterDimY, int fallbackClusterDimZ, bool fuseUtccpWithUtcmma,
104+
bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit,
105+
bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k,
106+
gemm::KernelTraits kernelTraits, gemm::MatrixLayout layoutA, gemm::MatrixLayout layoutB, int m, int mmaK,
107+
tg::MmaKind mmaKind, int mmaM, int mmaN, bool mockAllReduce, int n, int numEpilogueWarps, int numRegsCastAWarps,
108+
int numRegsCopySfLdsSttm, int numRegsCopySparsityInfo, int numRegsPerThreadEpilogueWarp,
109+
int numRegsPerThreadNonEpilogueWarp, int numSlicesForSplitK, int numSlicesForSliceK, int numStages,
110+
int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId,
111+
bool outputDebugTensors, bool patchF2fp, int32_t sfBlockSizeA, int32_t sfBlockSizeB, int32_t sfBlockSizeC,
112+
tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, int32_t sfReshapeFactor, bool sliceK,
113+
tg::Sparsity sparsityA, gemm::SplitK splitK, int tileK, int tileM, int tileN, gemm::TileScheduler tileScheduler,
114+
bool transposeMmaOutput, bool useCustomMmaSchedule, bool useDeepSeekFp8, bool useFlexibleClusterDims,
115115
bool useHoistTryWaitForCustomMmaSchedule, bool useMaxTmemOverlap, bool usePerTokenSfA, bool usePerTokenSfB,
116116
bool useShuffledMatrix, bool useTmaStore, bool useTwoTmaLoadWarps, bool useTwoMmaWarps,
117117
bool useUnrollLoop2xForMma, int validM, int validN, int validK, int worldSize,
@@ -127,17 +127,18 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
127127
gemm::GemmOptions(allReduceAlgo, biasType, blockK, clcFastDrain, clusterDimX, clusterDimY, clusterDimZ,
128128
ctaSwizzleType, dtypeAcc, dtypeA, dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, eltwiseActType,
129129
enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs, epilogueLdtmDps, epilogueLdtmBits,
130-
epilogueTileM, epilogueTileN, fuseUtccpWithUtcmma, gridTriggerSecondaryA, gridTriggerSecondaryB,
131-
gridWaitForPrimaryEarlyExit, gridWaitForPrimaryA, gridWaitForPrimaryB, hoistLoadTaskInit,
132-
hoistMmaTaskTryWaits, k, kernelTraits, layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n,
133-
numEpilogueWarps, numRegsCastAWarps, numRegsCopySfLdsSttm, numRegsCopySparsityInfo,
134-
numRegsPerThreadEpilogueWarp, numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, numSlicesForSliceK,
135-
numStages, numStagesMma, numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId,
136-
outputDebugTensors, patchF2fp, sfBlockSizeA, sfBlockSizeB, sfBlockSizeC, sfLayoutA, sfLayoutB,
137-
sfLayoutC, sfReshapeFactor, sliceK, sparsityA, splitK, tileK, tileM, tileN, tileScheduler,
138-
transposeMmaOutput, useCustomMmaSchedule, useDeepSeekFp8, useHoistTryWaitForCustomMmaSchedule,
139-
useMaxTmemOverlap, usePerTokenSfA, usePerTokenSfB, useShuffledMatrix, useTmaStore, useTwoTmaLoadWarps,
140-
useTwoMmaWarps, useUnrollLoop2xForMma, validM, validN, validK, worldSize),
130+
epilogueTileM, epilogueTileN, fallbackClusterDimX, fallbackClusterDimY, fallbackClusterDimZ,
131+
fuseUtccpWithUtcmma, gridTriggerSecondaryA, gridTriggerSecondaryB, gridWaitForPrimaryEarlyExit,
132+
gridWaitForPrimaryA, gridWaitForPrimaryB, hoistLoadTaskInit, hoistMmaTaskTryWaits, k, kernelTraits,
133+
layoutA, layoutB, m, mmaK, mmaKind, mmaM, mmaN, mockAllReduce, n, numEpilogueWarps, numRegsCastAWarps,
134+
numRegsCopySfLdsSttm, numRegsCopySparsityInfo, numRegsPerThreadEpilogueWarp,
135+
numRegsPerThreadNonEpilogueWarp, numSlicesForSplitK, numSlicesForSliceK, numStages, numStagesMma,
136+
numStagesMmaWithinWorkTile, numStagesMmaAcrossWorkTile, numStagesWorkId, outputDebugTensors, patchF2fp,
137+
sfBlockSizeA, sfBlockSizeB, sfBlockSizeC, sfLayoutA, sfLayoutB, sfLayoutC, sfReshapeFactor, sliceK,
138+
sparsityA, splitK, tileK, tileM, tileN, tileScheduler, transposeMmaOutput, useCustomMmaSchedule,
139+
useDeepSeekFp8, useFlexibleClusterDims, useHoistTryWaitForCustomMmaSchedule, useMaxTmemOverlap,
140+
usePerTokenSfA, usePerTokenSfB, useShuffledMatrix, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps,
141+
useUnrollLoop2xForMma, validM, validN, validK, worldSize),
141142
actType, clampBeforeAct)
142143
, mBatchedM(batchedM)
143144
, mBatchedN(batchedN)
@@ -310,7 +311,7 @@ inline bool checkAndUpdateBatchedGemmOptions(
310311
TLLM_CHECK_ERROR((options.mRouteSfsImpl.value() == RouteImpl::Ldgsts
311312
|| options.mRouteSfsImpl.value() == RouteImpl::LdgPlusSts)
312313
&& options.mRouteImpl == RouteImpl::Tma,
313-
"RouteSfsImpl must be equal to RouteImpl, or Ldgsts/LdgPlusSts, when RouteImpl is Tma");
314+
"RouteSfsImpl must be equal to RouteImpl, or Ldgsts/LdgPlusSts when RouteImpl is Tma");
314315
}
315316
else if (!options.mRouteSfsImpl.has_value())
316317
{
@@ -379,8 +380,6 @@ inline bool checkAndUpdateBatchedGemmOptions(
379380

380381
if (doesRouteImplUseTma(options.mRouteSfsImpl.value()))
381382
{
382-
TLLM_CHECK_ERROR(!batchM, "UTMALDG.GATHER4 only supported for batch N.");
383-
384383
if (tg::mmaKindIsBlockFmt(options.mMmaKind))
385384
{
386385
int const numEltsPerSfRoute = batchM ? options.mSfBlockSizeA : options.mSfBlockSizeB;
@@ -392,8 +391,9 @@ inline bool checkAndUpdateBatchedGemmOptions(
392391

393392
if (!batchM || doesRouteImplUseNoRoute(options.mRouteImpl))
394393
{
395-
TLLM_CHECK_ERROR(options.mSfLayoutA == tg::SfLayout::R128c4,
396-
"options.mSfLayoutA has to be tg::SfLayout::R128c4 when not being routed");
394+
bool isSupportedSfLayoutA = options.mSfLayoutA == tg::SfLayout::R128c4;
395+
TLLM_CHECK_ERROR(isSupportedSfLayoutA, "options.mSfLayoutA has to be R128cX when not batch M or not routed",
396+
tg::sfLayoutToString(options.mSfLayoutA));
397397
}
398398
}
399399

@@ -422,12 +422,6 @@ inline bool checkAndUpdateBatchedGemmOptions(
422422
options.mK % options.mTileK == 0, "K must be a multiple of tileK when using Ldg based SF routing");
423423
}
424424

425-
if (options.mClusterDimX > 1 && batchM && options.mRouteSfsImpl.has_value())
426-
{
427-
TLLM_CHECK_ERROR(options.mRouteSfsImpl.value() != RouteImpl::Tma,
428-
"2CTA BatchedGemm does not support routing Sf along M dimension with TMA.");
429-
}
430-
431425
// Check if all elements in mBatchedM or mBatchedN are the same (uniform tokens per batch) and
432426
// set mIsUniformNumTokensPerBatch and mBatchStride.
433427
if (options.mIsUniformNumTokensPerBatch)

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/Enums.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ enum class EltwiseActType
107107
// act = relu(x0) ^ 2
108108
// where x0 is the output of the Gemm.
109109
Relu2,
110+
// Silu is defined as the following operation:
111+
// act = x0 * sigmoid(x0)
112+
// where x0 is the output of the Gemm.
113+
Silu,
110114
};
111115

112116
////////////////////////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments (0)