Skip to content

Commit 7746bbc

Browse files
committed
fix ups
Signed-off-by: Athena Cai <athenac@nvidia.com>
1 parent 596d453 commit 7746bbc

10 files changed

Lines changed: 171 additions & 69 deletions

File tree

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -681,15 +681,13 @@ class GenerationRequest
681681
++mNumFrontBlocksRemovedPerWindow.at(windowSize);
682682
}
683683

684-
//! \brief Advance ``mNumFrontBlocksRemoved`` without touching cache blocks.
684+
//! \brief Advance the per-window front-block counter without touching cache blocks.
685685
//! \details Used by ``BlockManager::releasePrefixBlocks`` to advance the
686-
//! shared front-block counter once after every ``WindowBlockManager`` has
687-
//! processed the same prefix range. Has clearer intent than calling
688-
//! ``removeFrontBlock`` with a sentinel ``windowSize`` value, and is robust
689-
//! to future changes that consume the ``windowSize`` argument.
690-
void incrementNumFrontBlocksRemoved()
686+
//! single-window front-block counter once after every ``WindowBlockManager`` has
687+
//! processed the same prefix range.
688+
void incrementNumFrontBlocksRemoved(SizeType32 windowSize)
691689
{
692-
++mNumFrontBlocksRemoved;
690+
++mNumFrontBlocksRemovedPerWindow.at(windowSize);
693691
}
694692

695693
void removeLastBlock(SizeType32 windowSize)
@@ -989,7 +987,7 @@ class WindowBlockManager
989987
//! for blocks whose data has already been transferred. Reuses the
990988
//! detachFrontBlock mechanism (decRefCount + eviction policy release).
991989
//! Called by BlockManager::releasePrefixBlocks which coordinates the
992-
//! shared mNumFrontBlocksRemoved counter across all window managers.
990+
//! single-window front-block counter across all window managers.
993991
void releasePrefixBlocks(GenerationRequest& sequence, SizeType32 startIdx, SizeType32 numBlocks);
994992

995993
//! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
@@ -1535,7 +1533,7 @@ class BlockManager
15351533

15361534
//! \brief Release the first numBlocks prefix blocks of a sequence.
15371535
//! \details Mirrors detachFrontBlock logic: decRefCount + eviction policy
1538-
//! release for each prefix block. The mNumFrontBlocksRemoved counter on
1536+
//! release for each prefix block. The front-block counter on
15391537
//! GenerationRequest ensures releaseBlocks (called during removeSequence)
15401538
//! skips already-freed prefix blocks.
15411539
void releasePrefixBlocks(GenerationRequest& sequence, SizeType32 numBlocks);

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2906,22 +2906,22 @@ void BlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeType32 n
29062906
// today (gated by should_store_blocks: not is_vswa in the executor and
29072907
// beamWidth == 1 assertion in WindowBlockManager::releasePrefixBlocks).
29082908
//
2909+
auto const windowSize = mWindowBlockManagers.cbegin()->first;
29092910
// Snapshot the counter before iterating so that every WindowBlockManager
29102911
// releases the same range. Without this, the first manager would advance
2911-
// the shared mNumFrontBlocksRemoved counter and subsequent managers would
2912-
// see the counter already at the target, skipping their own blocks.
2913-
SizeType32 const startIdx = sequence.getNumFrontBlocksRemoved();
2912+
// the single-window front-block counter and subsequent managers would see
2913+
// the counter already at the target, skipping their own blocks.
2914+
SizeType32 const startIdx = sequence.getNumFrontBlocksRemoved(windowSize);
29142915
for (auto& [_, manager] : mWindowBlockManagers)
29152916
{
29162917
manager.releasePrefixBlocks(sequence, startIdx, numBlocks);
29172918
}
2918-
// Advance the shared counter once, after all managers have released.
2919+
// Advance the single-window counter once, after all managers have released.
29192920
// Uses incrementNumFrontBlocksRemoved (counter-only) instead of
2920-
// removeFrontBlock so the intent is explicit and we do not depend on
2921-
// removeFrontBlock ignoring its windowSize argument.
2922-
while (sequence.getNumFrontBlocksRemoved() < numBlocks)
2921+
// removeFrontBlock so the intent is explicit.
2922+
while (sequence.getNumFrontBlocksRemoved(windowSize) < numBlocks)
29232923
{
2924-
sequence.incrementNumFrontBlocksRemoved();
2924+
sequence.incrementNumFrontBlocksRemoved(windowSize);
29252925
}
29262926
}
29272927

@@ -3746,23 +3746,30 @@ void WindowBlockManager::releasePrefixBlocks(GenerationRequest& sequence, SizeTy
37463746
auto& allocatedBlocks = mAllocatedBlocksPerSeq.at(requestId);
37473747
SizeType32 const target = std::min(numBlocks, static_cast<SizeType32>(allocatedBlocks.size()));
37483748

3749-
// Release blocks in range [startIdx, target). The shared
3750-
// mNumFrontBlocksRemoved counter is advanced by BlockManager after
3749+
// Release blocks in range [startIdx, target). The single-window
3750+
// front-block counter is advanced by BlockManager after
37513751
// all WindowBlockManagers have processed the same range.
37523752
for (SizeType32 blockIdx = startIdx; blockIdx < target; ++blockIdx)
37533753
{
37543754
auto& block = allocatedBlocks.at(blockIdx);
3755+
auto releasedBlock = block;
37553756

37563757
TLLM_LOG_DEBUG("%s::releasePrefixBlocks - Releasing block %d from sequence %lu", mLogPrefix.c_str(),
3757-
block->getBlockId(), requestId);
3758+
releasedBlock->getBlockId(), requestId);
37583759

3759-
if (block->hasRefs())
3760+
// Replace the sequence slot with a placeholder, matching detachFrontBlock().
3761+
// removeSequence later walks allocatedBlocks in releaseBlocks(); leaving the
3762+
// real block here would release it a second time and corrupt the eviction
3763+
// policy's free-block count.
3764+
block = KVCacheBlock::createPlaceholder();
3765+
3766+
if (releasedBlock->hasRefs())
37603767
{
3761-
block->decRefCount();
3768+
releasedBlock->decRefCount();
37623769
}
3763-
if (!block->hasRefs())
3770+
if (!releasedBlock->hasRefs())
37643771
{
3765-
mEvictionPolicy->releaseBlock(block);
3772+
mEvictionPolicy->releaseBlock(releasedBlock);
37663773
}
37673774
}
37683775
}
@@ -3945,8 +3952,8 @@ std::optional<KVCacheBlock::IdType> KVCacheManager::removeSequence(
39453952

39463953
void KVCacheManager::releasePrefixBlocks(RequestIdType requestId, SizeType32 numBlocks)
39473954
{
3948-
// Hard precondition: BlockManager::releasePrefixBlocks advances the shared
3949-
// mNumFrontBlocksRemoved counter to numBlocks for every WindowBlockManager,
3955+
// Hard precondition: BlockManager::releasePrefixBlocks advances the
3956+
// single-window front-block counter to numBlocks after releasing the same range in every WindowBlockManager,
39503957
// even when a window has fewer than numBlocks allocated. Under variable
39513958
// sliding window attention (VSWA), that would cause WindowBlockManager::
39523959
// releaseBlocks (called during removeSequence) to underrun rbegin() and

cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,59 @@ TEST_F(KVCacheManagerTest, BlockManagerTest)
270270
std::runtime_error);
271271
}
272272

273+
TEST_F(KVCacheManagerTest, BlockManagerReleasePrefixBlocksDoesNotDoubleFreeOnTeardown)
274+
{
275+
auto constexpr numLayers = 12;
276+
auto constexpr numKvHeads = 6;
277+
auto constexpr sizePerHead = 128;
278+
auto constexpr tokensPerBlock = 4;
279+
auto constexpr blocksInPrimaryPool = 8;
280+
auto constexpr blocksInSecondaryPool = 0;
281+
auto constexpr maxNumSequences = 8;
282+
auto const stream = std::make_shared<tr::CudaStream>();
283+
284+
auto constexpr beamWidth = 1;
285+
auto constexpr numBlocksPerBeam = 4;
286+
auto constexpr numTokens = tokensPerBlock * numBlocksPerBeam;
287+
auto constexpr maxAttentionWindow = numTokens;
288+
289+
auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};
290+
291+
BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
292+
maxNumSequences, stream, maxAttentionWindow, beamWidth,
293+
std::vector<BlockManager::SizeType32>{maxAttentionWindow}, nvinfer1::DataType::kHALF, 0, maxAttentionWindow);
294+
blockManager.allocatePools(false);
295+
296+
SizeType32 constexpr maxNewTokens{0};
297+
tr::SamplingConfig const samplingConfig{beamWidth};
298+
bool constexpr isStreaming{false};
299+
300+
auto tokens = std::make_shared<VecTokens>();
301+
for (SizeType32 i = 0; i < numTokens; ++i)
302+
{
303+
tokens->push_back(i);
304+
}
305+
306+
LlmRequest::RequestIdType constexpr requestId{42};
307+
auto llmReq = std::make_shared<LlmRequest>(requestId, maxNewTokens, tokens, samplingConfig, isStreaming);
308+
GenerationRequest seq{requestId, numTokens, beamWidth, blockManager.getWindowSizesMetadata()};
309+
310+
(void) blockManager.addSequenceBatch(
311+
{&seq}, {numTokens}, {numBlocksPerBeam}, {std::ref(*llmReq)}, maxAttentionWindow, /*isEnableBlockReuse=*/false);
312+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocksPerBeam);
313+
314+
blockManager.releasePrefixBlocks(seq, 2);
315+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - 2);
316+
317+
// releasePrefixBlocks has cumulative semantics. This should release only
318+
// one additional block rather than releasing the first two again.
319+
blockManager.releasePrefixBlocks(seq, 3);
320+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - 1);
321+
322+
blockManager.releaseBlocks(seq);
323+
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
324+
}
325+
273326
template <typename T>
274327
void writePatternToOffloadedBlocksDRAM(T* rawBlockPtr, int blockSize, int mask)
275328
{

tensorrt_llm/_torch/disaggregation/native/transfer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,10 @@ def disagg_request_id(self) -> int:
11681168
def status(self) -> SessionStatus:
11691169
if self._terminal_status is not None:
11701170
return self._terminal_status
1171+
if self._exception is not None or any(t.status == TaskStatus.ERROR for t in self.kv_tasks):
1172+
return SessionStatus.ERROR
1173+
if self.aux_task is not None and self.aux_task.status == TaskStatus.ERROR:
1174+
return SessionStatus.ERROR
11711175
kv_all_transferred = bool(self.kv_tasks) and all(
11721176
t.status == TaskStatus.TRANSFERRED for t in self.kv_tasks
11731177
)
@@ -1755,15 +1759,15 @@ def process_kv_agent_result(
17551759
)
17561760

17571761
def process_aux_agent_result(self, _peer_rank: int, status: AgentResult):
1758-
# Aux is session-level (not per-slice); expected_transfers is identical
1759-
# across all kv_tasks, so any task provides the right count.
1762+
# Aux is session-level (not per-slice); use the final KV task's
1763+
# expected transfer count so chunked sessions wait for all senders.
17601764
with self.lock:
17611765
if not self._kv_tasks:
17621766
logger.warning(
17631767
f"Aux result received before any KV tasks for request {self.request_id}"
17641768
)
17651769
return
1766-
task = self._kv_tasks[0]
1770+
task = self._kv_tasks[-1]
17671771
if status == AgentResult.SUCCESS:
17681772
self._aux_count += 1
17691773

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,10 @@ def _make_chunk_callback(self) -> Optional[Callable]:
353353
release_queue = self._pending_prefix_releases
354354

355355
def _on_chunk_transferred(request_id: int, chunk_block_offset: int, num_blocks: int):
356+
logger.debug(
357+
f"Early release _on_chunk_transferred: request_id: {request_id}, "
358+
f"chunk_block_offset: {chunk_block_offset}, num_blocks: {num_blocks}"
359+
)
356360
cumulative_blocks = chunk_block_offset + num_blocks
357361
release_queue.put((request_id, cumulative_blocks))
358362

tensorrt_llm/_torch/models/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from .modeling_deepseekv3 import DeepseekV3ForCausalLM
1616
from .modeling_exaone4 import Exaone4ForCausalLM
1717
from .modeling_exaone4_5 import Exaone4_5_ForConditionalGeneration
18-
from .modeling_exaone_moe import ExaoneMoeForCausalLM
1918
from .modeling_gemma3 import Gemma3ForCausalLM
2019
from .modeling_gemma3vl import Gemma3VLM
2120
from .modeling_glm import Glm4MoeForCausalLM
@@ -57,6 +56,11 @@
5756
from .modeling_utils import get_model_architecture
5857
from .modeling_vila import VilaModel
5958

59+
try:
60+
from .modeling_exaone_moe import ExaoneMoeForCausalLM
61+
except ImportError:
62+
ExaoneMoeForCausalLM = None
63+
6064
# Note: for better readability, this should have the same order as the imports above
6165
__all__ = [
6266
"AfmoeForCausalLM",
@@ -67,7 +71,6 @@
6771
"DeepseekV3ForCausalLM",
6872
"Exaone4ForCausalLM",
6973
"Exaone4_5_ForConditionalGeneration",
70-
"ExaoneMoeForCausalLM",
7174
"Gemma3ForCausalLM",
7275
"Gemma3VLM",
7376
"HCXVisionForCausalLM",
@@ -116,6 +119,9 @@
116119
"Step3p7VLForConditionalGeneration",
117120
]
118121

122+
if ExaoneMoeForCausalLM is not None:
123+
__all__.append("ExaoneMoeForCausalLM")
124+
119125
if transformers.__version__ >= "4.45.1":
120126
from .modeling_mllama import MllamaForConditionalGeneration # noqa
121127

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,49 @@ def test_kv_cache_v2_nixl_python(self):
733733
self.MODEL_PATH) as llm:
734734
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
735735

736+
@skip_pre_hopper
737+
@pytest.mark.skip_less_device(2)
738+
@parametrize_with_ids("chunk_size_blocks", [64])
739+
@parametrize_with_ids("enable_block_reuse", [False, True])
740+
def test_chunked_kv_transfer_nixl_python_accuracy(self,
741+
chunk_size_blocks: int,
742+
enable_block_reuse: bool):
743+
"""Test chunked KV transfer accuracy using Python transceiver and C++ KVCacheManager."""
744+
kv_cache_config = {
745+
"use_kv_cache_manager_v2": False,
746+
"enable_block_reuse": enable_block_reuse,
747+
}
748+
cache_transceiver_config = {
749+
"backend": "NIXL",
750+
"transceiver_runtime": "PYTHON",
751+
"max_tokens_in_buffer": 4096,
752+
"chunk_size_blocks": chunk_size_blocks,
753+
}
754+
ctx_server_config = {
755+
"disable_overlap_scheduler": True,
756+
"kv_cache_config": dict(kv_cache_config),
757+
"cache_transceiver_config": dict(cache_transceiver_config),
758+
}
759+
gen_server_config = {
760+
"disable_overlap_scheduler": False,
761+
"kv_cache_config": dict(kv_cache_config),
762+
"cache_transceiver_config": dict(cache_transceiver_config),
763+
}
764+
disaggregated_server_config = {
765+
"hostname": "localhost",
766+
"backend": "pytorch",
767+
"context_servers": {
768+
"num_instances": 1,
769+
},
770+
"generation_servers": {
771+
"num_instances": 1,
772+
},
773+
}
774+
with launch_disaggregated_llm(disaggregated_server_config,
775+
ctx_server_config, gen_server_config,
776+
self.MODEL_PATH) as llm:
777+
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
778+
736779
@pytest.mark.skip_less_device(2)
737780
def test_ngram(self):
738781
speculative_decoding_config = {

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ l0_dgx_b200:
1818
- unittest/_torch/misc/test_autotuner.py::test_autotuner_distributed_strategy
1919
- accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp2-CUTLASS]
2020
- accuracy/test_llm_api_pytorch.py::TestQwen3_5_35B_A3B::test_bf16[tp2-TRTLLM]
21+
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_chunked_kv_transfer_nixl_python_accuracy
2122
# ------------- KV Cache V2 Scheduler IT (multi-GPU) ---------------
2223
- kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2DSv3Lite::test_mtp_draft_tokens
2324
- kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2DSv3Lite::test_mtp_chunked_draft_tokens

0 commit comments

Comments
 (0)