Skip to content

Commit b3a7381

Browse files
authored
[https://nvbugs/6025177][fix] Fix KV cache issue (#12673)
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com>
1 parent 68a041a commit b3a7381

14 files changed

Lines changed: 389 additions & 38 deletions

File tree

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,33 @@ std::vector<BlockPtr> getAllSequenceBlocks(BlockPtr lastBlock)
8787
return sequenceBlocks;
8888
}
8989

90+
// Compute maximum number of tokens that have been computed by prefill and generation.
91+
// Accounts for chunked prefill to avoid storing state that hasn't been written to KV cache yet.
92+
// We call LlmRequest::getContextRemainingLength to see how many tokens are still waiting to be computed in prefill.
93+
// If this value is > 0 prefill is not finished yet, and number of computed tokens must be capped at the current context
94+
// position. If it is == 0, we are in generation mode, and number of computed tokens equals number of unique tokens
95+
// stored in request.
96+
SizeType32 getMaterializedUniqueTokenCountForReuse(
97+
VecUniqueTokens const& uniqueTokens, tensorrt_llm::batch_manager::LlmRequest const& llmRequest)
98+
{
99+
auto const totalUniqueTokenCount = static_cast<SizeType32>(uniqueTokens.size());
100+
if (llmRequest.getContextRemainingLength() > 0)
101+
{
102+
return std::min(totalUniqueTokenCount, llmRequest.getContextCurrentPosition());
103+
}
104+
return totalUniqueTokenCount;
105+
}
106+
107+
// Compute number of tokens that can be stored for reuse. The last computed token is never stored in KV cache, hence
108+
// cannot be stored for reuse. Number of tokens that can be stored for reuse is thus the greater of 0 or
109+
// getMaterializedUniqueTokenCountForReuse() - 1.
110+
SizeType32 getUsableUniqueTokenCountForReuse(
111+
VecUniqueTokens const& uniqueTokens, tensorrt_llm::batch_manager::LlmRequest const& llmRequest)
112+
{
113+
auto const materializedUniqueTokenCount = getMaterializedUniqueTokenCountForReuse(uniqueTokens, llmRequest);
114+
return materializedUniqueTokenCount > 0 ? materializedUniqueTokenCount - 1 : 0;
115+
}
116+
90117
} // namespace
91118

92119
namespace tensorrt_llm::batch_manager::kv_cache_manager
@@ -926,12 +953,9 @@ void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest co
926953
auto const& uniqueTokens = llmRequest.getUniqueTokens(beamIdx);
927954
TLLM_LOG_DEBUG("storeContextBlocks for request %lu on window %d with %d unique tokens", llmRequest.mRequestId,
928955
windowSize, uniqueTokens.size());
929-
// only store the tokens that have been completed
930-
size_t const completedTokens = llmRequest.getContextCurrentPosition();
931-
auto usableSize = std::min(completedTokens, uniqueTokens.size() - 1);
932-
956+
auto const usableUniqueTokenCount = getUsableUniqueTokenCountForReuse(uniqueTokens, llmRequest);
933957
auto blockedUniqueTokens
934-
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, getTokensPerBlock(), false);
958+
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableUniqueTokenCount, getTokensPerBlock(), false);
935959
auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
936960
(void) manager.storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
937961
}
@@ -2369,17 +2393,17 @@ std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
23692393
auto constexpr beamIdx = 0;
23702394
auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx);
23712395
auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize);
2372-
// TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't
2373-
// have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume
2374-
// the last token's state is not filled yet.
2375-
auto usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
2396+
2397+
auto usableUniqueTokenCount = getUsableUniqueTokenCountForReuse(uniqueTokens, *llmRequest);
23762398
if (isRecurrentState())
23772399
{
2378-
usableSize = std::min(llmRequest->getPromptLen() - 1, usableSize); // TODO: enable store for completed sequences
2400+
usableUniqueTokenCount = std::min(
2401+
llmRequest->getPromptLen() - 1, usableUniqueTokenCount); // TODO: enable store for completed sequences
23792402
}
23802403
TLLM_LOG_DEBUG("%s::storeBlocksForReuse: req=%lu, windowSize=%d, uniqueTokens.size()=%zu, usableSize=%zu",
2381-
mLogPrefix.c_str(), llmRequest->mRequestId, mWindowSize, uniqueTokens.size(), usableSize);
2382-
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
2404+
mLogPrefix.c_str(), llmRequest->mRequestId, mWindowSize, uniqueTokens.size(), usableUniqueTokenCount);
2405+
auto blockedUniqueTokens
2406+
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableUniqueTokenCount, mTokensPerBlock, true);
23832407
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
23842408

23852409
auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);
@@ -2414,11 +2438,9 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
24142438
sequence.getRequestId());
24152439
}
24162440
auto const& uniqueTokens = llmRequest->getUniqueTokens(/*beamIdx=*/0);
2417-
// Only (length - 1) tokens of the sequence have their kv-state
2418-
// recorded in kv-cache. We assume the last token's state is not filled yet.
2419-
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
2420-
auto blockedUniqueTokens
2421-
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, /*allowPartial=*/true);
2441+
auto const usableUniqueTokenCount = getUsableUniqueTokenCountForReuse(uniqueTokens, *llmRequest);
2442+
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(
2443+
uniqueTokens, usableUniqueTokenCount, mTokensPerBlock, /*allowPartial=*/true);
24222444
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
24232445

24242446
std::vector<KVCacheBlock::IdType> cacheBlockIds(allocatedBlocks.size());

cpp/tensorrt_llm/nanobind/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ set(SRCS
2222
runtime/hostfunc.cpp
2323
runtime/moeBindings.cpp
2424
suffixAutomaton/bindings.cpp
25+
testing/kvCacheManagerTestUtilBinding.cpp
2526
testing/modelSpecBinding.cpp
2627
userbuffers/bindings.cpp
2728
thop/bindings.cpp

cpp/tensorrt_llm/nanobind/bindings.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "tensorrt_llm/nanobind/process_group/bindings.h"
4646
#include "tensorrt_llm/nanobind/runtime/bindings.h"
4747
#include "tensorrt_llm/nanobind/suffixAutomaton/bindings.h"
48+
#include "tensorrt_llm/nanobind/testing/kvCacheManagerTestUtilBinding.h"
4849
#include "tensorrt_llm/nanobind/testing/modelSpecBinding.h"
4950
#include "tensorrt_llm/nanobind/thop/bindings.h"
5051
#include "tensorrt_llm/nanobind/userbuffers/bindings.h"
@@ -512,6 +513,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
512513
tpb::Buffers::initBindings(mInternalBatchManager);
513514
tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime);
514515
tensorrt_llm::nanobind::testing::initBindings(mInternalTesting);
516+
tensorrt_llm::nanobind::testing::initKvCacheTestUtilBindings(mInternalTesting);
515517
tpb::initBindings(mInternalBatchManager);
516518

517519
tb::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(mInternalBatchManager);
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#include "kvCacheManagerTestUtilBinding.h"
19+
#include "tensorrt_llm/nanobind/common/customCasters.h"
20+
#include "tensorrt_llm/testing/kvCacheManagerTestUtil.h"
21+
22+
#include <nanobind/nanobind.h>
23+
24+
namespace nb = nanobind;
25+
26+
namespace tensorrt_llm::nanobind::testing
27+
{
28+
29+
/// @brief Register the KV cache manager test helpers on the given nanobind module.
///
/// Exposes KvCacheManagerTestUtil::simulatePrefillCompletion to Python under a deliberately verbose name.
/// The call guard releases the GIL for the duration of the call.
void initKvCacheTestUtilBindings(nb::module_& m)
{
    namespace tt = tensorrt_llm::testing;

    constexpr char const* kPythonName = "simulate_prefill_completion_only_use_for_testing";
    constexpr char const* kDocString
        = "NEVER USE IN PRODUCTION. Simulates prefill completion on an LlmRequest for test purposes.";

    m.def(kPythonName, &tt::KvCacheManagerTestUtil::simulatePrefillCompletion, nb::arg("llm_request"),
        nb::call_guard<nb::gil_scoped_release>(), kDocString);
}
36+
37+
} // namespace tensorrt_llm::nanobind::testing
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#pragma once
19+
20+
#include <nanobind/nanobind.h>
21+
22+
namespace nb = nanobind;
23+
24+
namespace tensorrt_llm::nanobind::testing
25+
{
26+
27+
/// @brief Register KV cache manager test-utility bindings on the nanobind module @p m.
/// @param m Module that receives the bindings.
void initKvCacheTestUtilBindings(nb::module_& m);
28+
29+
} // namespace tensorrt_llm::nanobind::testing
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#pragma once
19+
20+
#include "tensorrt_llm/batch_manager/llmRequest.h"
21+
22+
namespace tensorrt_llm::testing
23+
{
24+
25+
/// @brief Test utilities for KV cache manager unit tests. NEVER use in production code.
26+
class KvCacheManagerTestUtil
27+
{
28+
public:
29+
/// @brief Simulate completion of the prefill stage on an LlmRequest.
30+
///
31+
/// NEVER CALL FROM PRODUCTION CODE. This is solely for use in unit tests.
32+
///
33+
/// Most BlockManager/KVCacheManager functions (storeContextBlocks, releaseBlocks,
34+
/// removeSequence, releaseSequence) require prefill to be complete before they are
35+
/// called. This method updates llmRequest state as if prefill has just finished,
36+
/// allowing unit tests to invoke those functions correctly.
37+
static void simulatePrefillCompletion(batch_manager::LlmRequest& llmRequest)
38+
{
39+
llmRequest.setContextCurrentPosition(llmRequest.getPromptLen());
40+
}
41+
};
42+
43+
} // namespace tensorrt_llm::testing

cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "tensorrt_llm/executor/executor.h"
2929
#include "tensorrt_llm/executor/requestUtils.h"
3030
#include "tensorrt_llm/executor/types.h"
31+
#include "tensorrt_llm/testing/kvCacheManagerTestUtil.h"
3132

3233
#include <NvInferPlugin.h>
3334

@@ -401,6 +402,7 @@ int runTest(CapacityScheduler& capacityScheduler,
401402

402403
if (llmReq->getContextRemainingLength() == 0)
403404
{
405+
tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*llmReq);
404406
kvCacheManager->storeContextBlocks(*llmReq);
405407
if (crossKvCacheManager)
406408
{

0 commit comments

Comments
 (0)