|
46 | 46 | ) |
47 | 47 | from tensorrt_llm.serve.router import KvCacheAwareRouter, Router |
48 | 48 |
|
| 49 | +# Finish reasons for which a GEN handoff is still pending; any other reason means |
| 50 | +# the CTX request already completed and the disagg KV-cache handoff was never set up. |
| 51 | +_GEN_PENDING_FINISH_REASONS = ("length", "not_finished") |
| 52 | + |
49 | 53 |
|
50 | 54 | class OpenAIDisaggregatedService(OpenAIService): |
51 | 55 | def __init__( |
@@ -174,7 +178,7 @@ async def _send_disagg_request_ctx_first( |
174 | 178 | return ctx_response |
175 | 179 |
|
176 | 180 | def _need_gen(self, response: UCompletionResponse) -> bool: |
177 | | - if response and response.choices[0].finish_reason not in ["length", "not_finished"]: |
| 181 | + if response and response.choices[0].finish_reason not in _GEN_PENDING_FINISH_REASONS: |
178 | 182 | del response.choices[0].disaggregated_params |
179 | 183 | return False |
180 | 184 | return True |
@@ -384,24 +388,28 @@ async def _on_worker_event(self, worker_info: WorkerInfo, event_type: WatchEvent |
384 | 388 | async def _verify_ctx_response(self, ctx_response: UCompletionResponse) -> None: |
385 | 389 | if ctx_response: |
386 | 390 | for idx, choice in enumerate(ctx_response.choices): |
387 | | - choice = ctx_response.choices[idx] |
388 | 391 | if choice.disaggregated_params is None: |
389 | 392 | raise ValueError( |
390 | 393 | f"Context server choice {idx} did not return disaggregated params." |
391 | 394 | f" finish_reason={choice.finish_reason!r}" |
392 | 395 | ) |
393 | | - if choice.disaggregated_params.ctx_request_id is None: |
394 | | - raise ValueError( |
395 | | - f"Invalid disaggregated params: ctx_request_id is None for choice {idx}." |
396 | | - f" finish_reason={choice.finish_reason!r}," |
397 | | - f" disagg_request_id={choice.disaggregated_params.disagg_request_id!r}" |
398 | | - ) |
399 | | - if choice.disaggregated_params.disagg_request_id is None: |
400 | | - raise ValueError( |
401 | | - f"Invalid disaggregated params: disagg_request_id is None for choice {idx}." |
402 | | - f" finish_reason={choice.finish_reason!r}," |
403 | | - f" ctx_request_id={choice.disaggregated_params.ctx_request_id!r}" |
404 | | - ) |
| 396 | + # A CTX request that finished early (e.g. EOS during prefill) never |
| 397 | + # sets up the KV-cache handoff, so ctx_request_id/disagg_request_id |
| 398 | + # stay None. Only enforce them when a GEN handoff is still pending -- |
| 399 | + # mirroring _need_gen, which skips the handoff for these responses. |
| 400 | + if choice.finish_reason in _GEN_PENDING_FINISH_REASONS: |
| 401 | + if choice.disaggregated_params.ctx_request_id is None: |
| 402 | + raise ValueError( |
| 403 | + f"Invalid disaggregated params: ctx_request_id is None for choice {idx}." |
| 404 | + f" finish_reason={choice.finish_reason!r}," |
| 405 | + f" disagg_request_id={choice.disaggregated_params.disagg_request_id!r}" |
| 406 | + ) |
| 407 | + if choice.disaggregated_params.disagg_request_id is None: |
| 408 | + raise ValueError( |
| 409 | + f"Invalid disaggregated params: disagg_request_id is None for choice {idx}." |
| 410 | + f" finish_reason={choice.finish_reason!r}," |
| 411 | + f" ctx_request_id={choice.disaggregated_params.ctx_request_id!r}" |
| 412 | + ) |
405 | 413 | return ctx_response |
406 | 414 |
|
407 | 415 | async def _send_disagg_request_gen_first( |
|
0 commit comments