@@ -87,6 +87,33 @@ std::vector<BlockPtr> getAllSequenceBlocks(BlockPtr lastBlock)
8787 return sequenceBlocks;
8888}
8989
90+ // Compute maximum number of tokens that have been computed by prefill and generation.
91+ // Accounts for chunked prefill to avoid storing state that hasn't been written to KV cache yet.
92+ // We call LlmRequest::getContextRemainingLength to see how many tokens are still waiting to be computed in prefill.
93+ // If this value is > 0 prefill is not finished yet, and number of computed tokens must be capped at the current context
94+ // position. If it is == 0, we are in generation mode, and number of computed tokens equals number of unique tokens
95+ // stored in request.
96+ SizeType32 getMaterializedUniqueTokenCountForReuse (
97+ VecUniqueTokens const & uniqueTokens, tensorrt_llm::batch_manager::LlmRequest const & llmRequest)
98+ {
99+ auto const totalUniqueTokenCount = static_cast <SizeType32>(uniqueTokens.size ());
100+ if (llmRequest.getContextRemainingLength () > 0 )
101+ {
102+ return std::min (totalUniqueTokenCount, llmRequest.getContextCurrentPosition ());
103+ }
104+ return totalUniqueTokenCount;
105+ }
106+
107+ // Compute number of tokens that can be stored for reuse. The last computed token is never stored in KV cache, hence
108+ // cannot be stored for reuse. Number of tokens that can be stored for reuse is thus the greater of 0 or
109+ // getMaterializedUniqueTokenCountForReuse() - 1.
110+ SizeType32 getUsableUniqueTokenCountForReuse (
111+ VecUniqueTokens const & uniqueTokens, tensorrt_llm::batch_manager::LlmRequest const & llmRequest)
112+ {
113+ auto const materializedUniqueTokenCount = getMaterializedUniqueTokenCountForReuse (uniqueTokens, llmRequest);
114+ return materializedUniqueTokenCount > 0 ? materializedUniqueTokenCount - 1 : 0 ;
115+ }
116+
90117} // namespace
91118
92119namespace tensorrt_llm ::batch_manager::kv_cache_manager
@@ -926,12 +953,9 @@ void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest co
926953 auto const & uniqueTokens = llmRequest.getUniqueTokens (beamIdx);
927954 TLLM_LOG_DEBUG (" storeContextBlocks for request %lu on window %d with %d unique tokens" , llmRequest.mRequestId ,
928955 windowSize, uniqueTokens.size ());
929- // only store the tokens that have been completed
930- size_t const completedTokens = llmRequest.getContextCurrentPosition ();
931- auto usableSize = std::min (completedTokens, uniqueTokens.size () - 1 );
932-
956+ auto const usableUniqueTokenCount = getUsableUniqueTokenCountForReuse (uniqueTokens, llmRequest);
933957 auto blockedUniqueTokens
934- = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize , getTokensPerBlock (), false );
958+ = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableUniqueTokenCount , getTokensPerBlock (), false );
935959 auto blockKeys = buildBlockKeys (blockedUniqueTokens, llmRequest);
936960 (void ) manager.storeBlocks (std::move (blockKeys), cacheBlockIds[beamIdx]);
937961 }
@@ -2369,17 +2393,17 @@ std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
23692393 auto constexpr beamIdx = 0 ;
23702394 auto const & uniqueTokens = llmRequest->getUniqueTokens (beamIdx);
23712395 auto const & cacheBlockIds = sequence.getCacheBlockIds (mWindowSize );
2372- // TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't
2373- // have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume
2374- // the last token's state is not filled yet.
2375- auto usableSize = static_cast <runtime::SizeType32>(uniqueTokens.size ()) - 1 ;
2396+
2397+ auto usableUniqueTokenCount = getUsableUniqueTokenCountForReuse (uniqueTokens, *llmRequest);
23762398 if (isRecurrentState ())
23772399 {
2378- usableSize = std::min (llmRequest->getPromptLen () - 1 , usableSize); // TODO: enable store for completed sequences
2400+ usableUniqueTokenCount = std::min (
2401+ llmRequest->getPromptLen () - 1 , usableUniqueTokenCount); // TODO: enable store for completed sequences
23792402 }
23802403 TLLM_LOG_DEBUG (" %s::storeBlocksForReuse: req=%lu, windowSize=%d, uniqueTokens.size()=%zu, usableSize=%zu" ,
2381- mLogPrefix .c_str (), llmRequest->mRequestId , mWindowSize , uniqueTokens.size (), usableSize);
2382- auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock , true );
2404+ mLogPrefix .c_str (), llmRequest->mRequestId , mWindowSize , uniqueTokens.size (), usableUniqueTokenCount);
2405+ auto blockedUniqueTokens
2406+ = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableUniqueTokenCount, mTokensPerBlock , true );
23832407 auto blockKeys = buildBlockKeys (blockedUniqueTokens, *llmRequest);
23842408
23852409 auto [numStored, pinnedBlockIds] = storeBlocks (std::move (blockKeys), cacheBlockIds[beamIdx], pinBlocks);
@@ -2414,11 +2438,9 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
24142438 sequence.getRequestId ());
24152439 }
24162440 auto const & uniqueTokens = llmRequest->getUniqueTokens (/* beamIdx=*/ 0 );
2417- // Only (length - 1) tokens of the sequence have their kv-state
2418- // recorded in kv-cache. We assume the last token's state is not filled yet.
2419- auto const usableSize = static_cast <runtime::SizeType32>(uniqueTokens.size ()) - 1 ;
2420- auto blockedUniqueTokens
2421- = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock , /* allowPartial=*/ true );
2441+ auto const usableUniqueTokenCount = getUsableUniqueTokenCountForReuse (uniqueTokens, *llmRequest);
2442+ auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(
2443+ uniqueTokens, usableUniqueTokenCount, mTokensPerBlock , /* allowPartial=*/ true );
24222444 auto blockKeys = buildBlockKeys (blockedUniqueTokens, *llmRequest);
24232445
24242446 std::vector<KVCacheBlock::IdType> cacheBlockIds (allocatedBlocks.size ());
0 commit comments