Skip to content

Commit f3b718a

Browse files
[https://nvbugs/6245861][fix] Gate the two ID None-checks on finish_reason in _GEN_PENDING_FINISH_REASONS… (#14908)
Signed-off-by: tensorrt-cicd <90828364+tensorrt-cicd@users.noreply.github.com>
1 parent e45dda9 commit f3b718a

2 files changed

Lines changed: 41 additions & 15 deletions

File tree

tensorrt_llm/serve/openai_disagg_service.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@
4646
)
4747
from tensorrt_llm.serve.router import KvCacheAwareRouter, Router
4848

49+
# Finish reasons for which a GEN handoff is still pending; any other reason means
50+
# the CTX request already completed and the disagg KV-cache handoff was never set up.
51+
_GEN_PENDING_FINISH_REASONS = ("length", "not_finished")
52+
4953

5054
class OpenAIDisaggregatedService(OpenAIService):
5155
def __init__(
@@ -174,7 +178,7 @@ async def _send_disagg_request_ctx_first(
174178
return ctx_response
175179

176180
def _need_gen(self, response: UCompletionResponse) -> bool:
177-
if response and response.choices[0].finish_reason not in ["length", "not_finished"]:
181+
if response and response.choices[0].finish_reason not in _GEN_PENDING_FINISH_REASONS:
178182
del response.choices[0].disaggregated_params
179183
return False
180184
return True
@@ -384,24 +388,28 @@ async def _on_worker_event(self, worker_info: WorkerInfo, event_type: WatchEvent
384388
async def _verify_ctx_response(self, ctx_response: UCompletionResponse) -> None:
385389
if ctx_response:
386390
for idx, choice in enumerate(ctx_response.choices):
387-
choice = ctx_response.choices[idx]
388391
if choice.disaggregated_params is None:
389392
raise ValueError(
390393
f"Context server choice {idx} did not return disaggregated params."
391394
f" finish_reason={choice.finish_reason!r}"
392395
)
393-
if choice.disaggregated_params.ctx_request_id is None:
394-
raise ValueError(
395-
f"Invalid disaggregated params: ctx_request_id is None for choice {idx}."
396-
f" finish_reason={choice.finish_reason!r},"
397-
f" disagg_request_id={choice.disaggregated_params.disagg_request_id!r}"
398-
)
399-
if choice.disaggregated_params.disagg_request_id is None:
400-
raise ValueError(
401-
f"Invalid disaggregated params: disagg_request_id is None for choice {idx}."
402-
f" finish_reason={choice.finish_reason!r},"
403-
f" ctx_request_id={choice.disaggregated_params.ctx_request_id!r}"
404-
)
396+
# A CTX request that finished early (e.g. EOS during prefill) never
397+
# sets up the KV-cache handoff, so ctx_request_id/disagg_request_id
398+
# stay None. Only enforce them when a GEN handoff is still pending --
399+
# mirroring _need_gen, which skips the handoff for these responses.
400+
if choice.finish_reason in _GEN_PENDING_FINISH_REASONS:
401+
if choice.disaggregated_params.ctx_request_id is None:
402+
raise ValueError(
403+
f"Invalid disaggregated params: ctx_request_id is None for choice {idx}."
404+
f" finish_reason={choice.finish_reason!r},"
405+
f" disagg_request_id={choice.disaggregated_params.disagg_request_id!r}"
406+
)
407+
if choice.disaggregated_params.disagg_request_id is None:
408+
raise ValueError(
409+
f"Invalid disaggregated params: disagg_request_id is None for choice {idx}."
410+
f" finish_reason={choice.finish_reason!r},"
411+
f" ctx_request_id={choice.disaggregated_params.ctx_request_id!r}"
412+
)
405413
return ctx_response
406414

407415
async def _send_disagg_request_gen_first(

tests/unittest/disaggregated/test_openai_disagg_service.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ async def test_missing_ctx_request_id_includes_disagg_id(self):
429429
@pytest.mark.asyncio
430430
async def test_missing_disagg_request_id_includes_ctx_id(self):
431431
svc = _make_service("context_first")
432-
resp = _make_completion_response("", finish_reason="stop", disagg_request_id=555)
432+
resp = _make_completion_response("", finish_reason="length", disagg_request_id=555)
433433
resp.choices[0].disaggregated_params.disagg_request_id = None
434434
resp.choices[0].disaggregated_params.ctx_request_id = 555
435435
with pytest.raises(ValueError, match=r"disagg_request_id is None.*555"):
@@ -442,6 +442,24 @@ async def test_valid_response_passes(self):
442442
result = await svc._verify_ctx_response(resp)
443443
assert result is resp
444444

445+
@pytest.mark.asyncio
446+
async def test_completed_response_with_null_ctx_request_id_passes(self):
447+
# CTX finished early (finish_reason='stop'): no GEN handoff was set up,
448+
# so ctx_request_id is None. The verifier must accept it (NVBug 6245861).
449+
svc = _make_service("context_first")
450+
resp = _make_completion_response("", finish_reason="stop", disagg_request_id=42)
451+
resp.choices[0].disaggregated_params.ctx_request_id = None
452+
result = await svc._verify_ctx_response(resp)
453+
assert result is resp
454+
455+
@pytest.mark.asyncio
456+
async def test_completed_response_with_null_disagg_request_id_passes(self):
457+
svc = _make_service("context_first")
458+
resp = _make_completion_response("", finish_reason="stop", disagg_request_id=42)
459+
resp.choices[0].disaggregated_params.disagg_request_id = None
460+
result = await svc._verify_ctx_response(resp)
461+
assert result is resp
462+
445463

446464
class TestFirstGenLogProbsSerializeRoundtrip:
447465
"""Roundtrip tests for _serialize/_deserialize_first_gen_log_probs."""

0 commit comments

Comments
 (0)