Skip to content

Commit cedaf8d

Browse files
peihu-nv and suyoggupta
authored and committed
[https://nvbugs/5961736][fix] Prebuild disagg ctx response to avoid ctx_request_id race (NVIDIA#12466)
Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com>
1 parent d000a61 commit cedaf8d

5 files changed

Lines changed: 123 additions & 7 deletions

File tree

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1717,7 +1717,8 @@ class GenericLlmRequest
17171717

17181718
[[nodiscard]] bool isFinished() const noexcept
17191719
{
1720-
return isGenerationCompleteState() || mState == LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS;
1720+
return isGenerationCompleteState() || mState == LlmRequestState::kDISAGG_CONTEXT_TRANS_IN_PROGRESS
1721+
|| isDisaggContextCompleteState();
17211722
}
17221723

17231724
/// Returns true if finished_reason is length for all beams

cpp/tensorrt_llm/batch_manager/llmRequest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int
8989

9090
auto const maxNbTokens = getMaxBeamNumTokens();
9191

92-
if (isDisaggContextTransmissionState() && isContextOnlyRequest())
92+
if ((isDisaggContextTransmissionState() || isDisaggContextCompleteState()) && isContextOnlyRequest())
9393
{
9494
auto const reqBeamWidth = mSamplingConfig.beamWidth;
9595
std::vector<TokenIdType> firstGenTokens;

cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,44 @@ TEST_P(ParamTest, createResponse)
751751
}
752752
}
753753

754+
// Regression test for nvbug/5961736: createResult() must produce a valid
755+
// response with contextPhaseParams when the request is in
756+
// kDISAGG_CONTEXT_COMPLETE, not just kDISAGG_CONTEXT_TRANS_IN_PROGRESS.
757+
// Without the fix, createResult() returns nullopt for CONTEXT_COMPLETE,
758+
// causing ctx_request_id=None in the disaggregated serving response.
759+
TEST_F(LlmRequestTest, createResultDisaggContextComplete)
760+
{
761+
VecTokens inputTokens{1, 2, 3, 4, 5};
762+
SizeType32 maxNewTokens{10};
763+
texec::IdType requestId{42};
764+
765+
// Build an executor::Request and configure it as context-only with ContextPhaseParams.
766+
texec::Request execReq(inputTokens, maxNewTokens);
767+
execReq.setRequestType(texec::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
768+
texec::ContextPhaseParams ctxParams({100}, requestId, static_cast<void*>(nullptr), std::nullopt);
769+
execReq.setContextPhaseParams(std::move(ctxParams));
770+
771+
tb::LlmRequest llmReq(requestId, execReq);
772+
EXPECT_TRUE(llmReq.isContextOnlyRequest());
773+
774+
// Add a generated token (required by createResult's firstGenTokens extraction).
775+
llmReq.addNewTokens(VecTokens{42});
776+
777+
// Verify isFinished() covers DISAGG_CONTEXT_COMPLETE.
778+
llmReq.setState(tb::LlmRequestState::kDISAGG_CONTEXT_COMPLETE);
779+
EXPECT_TRUE(llmReq.isFinished());
780+
781+
// This is the regression case — without the fix, createResult() returns nullopt
782+
// because DISAGG_CONTEXT_COMPLETE was not handled by createResult's early guard
783+
// or its context-phase branch.
784+
auto response = llmReq.createResult(/*useFastLogits=*/false, /*mpiWorldRank=*/0);
785+
ASSERT_TRUE(response.has_value()) << "createResult() must not return nullopt for DISAGG_CONTEXT_COMPLETE";
786+
EXPECT_TRUE(response->contextPhaseParams.has_value())
787+
<< "contextPhaseParams must be populated for context-only DISAGG_CONTEXT_COMPLETE requests";
788+
EXPECT_EQ(response->contextPhaseParams->getReqId(), requestId);
789+
EXPECT_TRUE(response->isSequenceFinal);
790+
}
791+
754792
INSTANTIATE_TEST_SUITE_P(LlmRequestTest, ParamTest,
755793
testing::Combine(
756794
// TODO: Support and add coverage for streamLLM

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3530,6 +3530,9 @@ def _handle_first_token_response(self, scheduled_batch):
35303530
def _handle_responses(self):
35313531
new_responses = []
35323532
requests_to_terminate = []
3533+
# Requests terminated by _check_disagg_ctx_cache_transfer_status (DISAGG_CONTEXT_COMPLETE);
3534+
# included in the return value for stats but not re-terminated here.
3535+
requests_finished_by_transfer = []
35333536
new_active_requests = []
35343537
logger.debug(
35353538
f'------before _handle_responses, rank = {self.dist.rank}, output = {self.active_requests}'
@@ -3618,11 +3621,18 @@ def _handle_responses(self):
36183621
# If partial reuse is enabled, and the KV cache manager is not VSWA, and the PP size is 1,
36193622
# then we need to terminate the request. TODO: Remove this once disagg support from KVCache reuse
36203623
# path is fixed.
3621-
if self.enable_partial_reuse_for_disagg and not self.kv_cache_manager.is_vswa and self.dist.pp_size == 1:
3624+
force_terminate_for_partial_reuse = (
3625+
self.enable_partial_reuse_for_disagg
3626+
and not self.kv_cache_manager.is_vswa
3627+
and self.dist.pp_size == 1)
3628+
if request.is_disagg_context_complete_state:
3629+
# Already terminated by _check_disagg_ctx_cache_transfer_status;
3630+
# track for stats only to avoid double-free (nvbug/5961736).
3631+
requests_finished_by_transfer.append(request)
3632+
elif force_terminate_for_partial_reuse:
3633+
requests_to_terminate.append(request)
3634+
elif not request.is_disagg_context_transmission_state:
36223635
requests_to_terminate.append(request)
3623-
else:
3624-
if not request.is_disagg_context_transmission_state:
3625-
requests_to_terminate.append(request)
36263636
else:
36273637
new_active_requests.append(request)
36283638

@@ -3632,7 +3642,7 @@ def _handle_responses(self):
36323642
self._enqueue_responses(new_responses)
36333643
for request in requests_to_terminate:
36343644
self._terminate_request(request)
3635-
return requests_to_terminate
3645+
return requests_to_terminate + requests_finished_by_transfer
36363646

36373647
def _await_any_response(self,
36383648
timeout: Optional[float] = None

tests/unittest/_torch/executor/test_py_executor.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,70 @@ def test_getter_methods(mock_executor):
178178
assert mock_executor.get_expected_num_active_requests() == 5
179179
assert mock_executor._get_new_active_requests_queue_latency() == 10.5
180180
assert mock_executor.get_waiting_queue_size() == 1
181+
182+
183+
def _classify_termination(request, enable_partial_reuse_for_disagg, is_vswa, pp_size):
184+
"""Reproduce the termination logic from _handle_responses (py_executor.py).
185+
186+
Returns:
187+
"terminate" | "stats_only" | "skip"
188+
"""
189+
force_terminate_for_partial_reuse = (
190+
enable_partial_reuse_for_disagg and not is_vswa and pp_size == 1
191+
)
192+
if request.is_disagg_context_complete_state:
193+
return "stats_only"
194+
elif force_terminate_for_partial_reuse:
195+
return "terminate"
196+
elif not request.is_disagg_context_transmission_state:
197+
return "terminate"
198+
return "skip"
199+
200+
201+
def _make_request(complete_state, transmission_state):
202+
req = Mock()
203+
req.is_disagg_context_complete_state = complete_state
204+
req.is_disagg_context_transmission_state = transmission_state
205+
return req
206+
207+
208+
class TestDisaggTerminationGuard:
209+
"""Verify _handle_responses does not double-terminate DISAGG_CONTEXT_COMPLETE
210+
requests that were already cleaned up by _check_disagg_ctx_cache_transfer_status
211+
(nvbug/5961736)."""
212+
213+
def test_normal_path_skips_context_complete(self):
214+
"""Without partial reuse, CONTEXT_COMPLETE goes to stats only."""
215+
req = _make_request(complete_state=True, transmission_state=False)
216+
assert _classify_termination(req, False, False, 1) == "stats_only"
217+
218+
def test_normal_path_skips_transmission_in_progress(self):
219+
"""Without partial reuse, TRANS_IN_PROGRESS is skipped (still in flight)."""
220+
req = _make_request(complete_state=False, transmission_state=True)
221+
assert _classify_termination(req, False, False, 1) == "skip"
222+
223+
def test_normal_path_terminates_regular_request(self):
224+
"""Without partial reuse, a normal finished request is terminated."""
225+
req = _make_request(complete_state=False, transmission_state=False)
226+
assert _classify_termination(req, False, False, 1) == "terminate"
227+
228+
def test_partial_reuse_terminates_non_complete(self):
229+
"""With partial reuse, non-CONTEXT_COMPLETE requests are terminated."""
230+
for complete, transmission in [(False, True), (False, False)]:
231+
req = _make_request(complete, transmission)
232+
assert _classify_termination(req, True, False, 1) == "terminate"
233+
234+
def test_partial_reuse_skips_context_complete(self):
235+
"""With partial reuse, CONTEXT_COMPLETE still goes to stats only."""
236+
req = _make_request(complete_state=True, transmission_state=False)
237+
assert _classify_termination(req, True, False, 1) == "stats_only"
238+
239+
def test_partial_reuse_disabled_by_vswa(self):
240+
"""VSWA disables partial reuse path, falling back to normal logic."""
241+
req = _make_request(complete_state=True, transmission_state=False)
242+
assert _classify_termination(req, True, True, 1) == "stats_only"
243+
244+
def test_partial_reuse_disabled_by_pp(self):
245+
"""PP > 1 disables partial reuse path, falling back to normal logic."""
246+
req = _make_request(complete_state=True, transmission_state=False)
247+
assert _classify_termination(req, True, False, 2) == "stats_only"

0 commit comments

Comments
 (0)