Skip to content

Commit a622e30

Browse files
authored
[TRTLLM-11538][feat] Blackwell custom mask fmha support (#12958)
Signed-off-by: qgai <qgai@nvidia.com>
1 parent 205920d commit a622e30

14 files changed

Lines changed: 421 additions & 76 deletions

File tree

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,6 +3157,7 @@ int AttentionOp::initialize() noexcept
31573157
fixedParams.isSpecDecoding = mIsSpecDecodingEnabled;
31583158
fixedParams.hasAlibi = isALiBi();
31593159
fixedParams.useTllmGenSparseAttention = useTllmGenSparseAttention();
3160+
fixedParams.specDecodingTargetMaxGenLen = mSpecDecodingTargetMaxGenLen;
31603161

31613162
mXqaDispatcher.reset(new XqaDispatcher(fixedParams));
31623163

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,8 @@ class AttentionOp
491491
bool mIsSpecDecTree = true;
492492
bool mSpecDecodingIsGenerationLengthVariable = false;
493493
int32_t mSpecDecodingMaxGenerationLength = 1;
494+
// Static spec-dec tree length used by FMHA autotuning.
495+
int32_t mSpecDecodingTargetMaxGenLen = 0;
494496
bool mIsMLAEnabled = false;
495497
bool mIsGenerationMLA = false;
496498
bool mUseGenFlashMLA = false;
@@ -559,13 +561,14 @@ class AttentionOp
559561
mCrossAttention, mMaxDistance, mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput,
560562
mFP8ContextMLA, mFP8GenerationMLA, mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask,
561563
mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable,
562-
mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mUseSparseAttention,
563-
mUseTllmGenSparseAttentionPaged, mUseTllmGenSparseAttention, mMLAParams.data(), mCpSize, mCpRank, mCpGroup,
564-
mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank,
565-
mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache,
566-
mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1),
567-
mSkipSoftmaxThresholdScaleFactorPrefill, mSkipSoftmaxThresholdScaleFactorDecode, mSageAttnNumEltsPerBlkQ,
568-
mSageAttnNumEltsPerBlkK, mSageAttnNumEltsPerBlkV, mSageAttnQkInt8);
564+
mSpecDecodingMaxGenerationLength, mSpecDecodingTargetMaxGenLen, mIsMLAEnabled, mIsGenerationMLA,
565+
mUseGenFlashMLA, mUseSparseAttention, mUseTllmGenSparseAttentionPaged, mUseTllmGenSparseAttention,
566+
mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin,
567+
mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA,
568+
mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
569+
mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1), mSkipSoftmaxThresholdScaleFactorPrefill,
570+
mSkipSoftmaxThresholdScaleFactorDecode, mSageAttnNumEltsPerBlkQ, mSageAttnNumEltsPerBlkK,
571+
mSageAttnNumEltsPerBlkV, mSageAttnQkInt8);
569572
};
570573

571574
private:

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,8 @@ class TllmGenFmhaKernel
997997
options.mIsCustomSpecDecodingGen = !isContext && params.mMaxSeqLenQ > 1 && params.mIsSpecDecTree;
998998
options.mIsCausalSpecDecodingGen = !isContext && params.mMaxSeqLenQ > 1 && !params.mIsSpecDecTree;
999999
options.mNumSpecDecodingTokens = !isContext && params.mMaxSeqLenQ > 1 ? params.mMaxSeqLenQ : 0;
1000+
// Carry static tree length into FMHA kernel selection.
1001+
options.mSpecDecodingTargetMaxGenLen = params.mSpecDecodingTargetMaxGenLen;
10001002

10011003
options.mIsTrtllmLayout = true;
10021004
}
@@ -1020,17 +1022,43 @@ class TllmGenFmhaKernel
10201022
// loop. And the number of loops is not the same in different tasks.
10211023
sstream << "\"checksTaskSchedules\": false,\n";
10221024

1025+
bool hasCompileDefs = false;
1026+
auto writeCompileDef = [&](char const* compileDef)
1027+
{
1028+
if (!hasCompileDefs)
1029+
{
1030+
sstream << "\"compileDefs\": [";
1031+
hasCompileDefs = true;
1032+
}
1033+
else
1034+
{
1035+
sstream << ", ";
1036+
}
1037+
sstream << "\"" << compileDef << "\"";
1038+
};
1039+
10231040
if (options.mIsExportingCubin)
10241041
{
1025-
sstream << "\"compileDefs\": [\"-DTLLM_EXPORT_CUBIN\"],\n";
1042+
writeCompileDef("-DTLLM_EXPORT_CUBIN");
10261043
}
10271044

10281045
// Set compile flags for E2M1 KV kernel benchmark.
10291046
// NOTE(tizheng): This is to be removed after compiler fixes PTX exposure of QMUL4. See Fp4Utils.h for details.
10301047
if (options.mChecksResults == 0 && options.mDtypeKv == tg::Dtype::E2m1)
10311048
{
10321049
TLLM_LOG_INFO("Forcing -DTLLM_BENCHMARK_E2M1_KV_CACHE for E2m1 Kv. The results are not correct.");
1033-
sstream << "\"compileDefs\": [\"-DTLLM_BENCHMARK_E2M1_KV_CACHE\"],\n";
1050+
writeCompileDef("-DTLLM_BENCHMARK_E2M1_KV_CACHE");
1051+
}
1052+
1053+
// SwapsMmaAb NVRTC kernels already emit __launch_bounds__; avoid a CUDA 13 .reqntid/.maxntid conflict.
1054+
if (shouldUseNvrtc(options) && options.mFmhaKernelType == FmhaKernelType::SwapsMmaAbForGeneration)
1055+
{
1056+
writeCompileDef("-DTLLM_DISABLE_BLOCK_SIZE");
1057+
}
1058+
1059+
if (hasCompileDefs)
1060+
{
1061+
sstream << "],\n";
10341062
}
10351063

10361064
// Enable programmatic dependent launch.

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaRunnerParams.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ struct TllmGenFmhaRunnerParams
342342
// When seqlensQPtr[i] < mPackedMaskMaxSeqLenQ, the packed mask tensor has
343343
// row stride ceilDiv(mPackedMaskMaxSeqLenQ, 32) rather than ceilDiv(seqLenQ, 32).
344344
int32_t mPackedMaskMaxSeqLenQ = 0;
345+
int32_t mSpecDecodingTargetMaxGenLen = 0;
345346

346347
// set the attention mask type
347348
TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType)

0 commit comments

Comments
 (0)