Skip to content

Commit 163be83

Browse files
authored
[https://nvbugs/6223556][fix] Propagate gen-first ctx usage via aux buffer to postproc (#15246)
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent f3b718a commit 163be83

6 files changed

Lines changed: 34 additions & 10 deletions

File tree

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,9 @@ def __init__(self,
578578
self._py_result = py_result
579579
self.is_final = is_final
580580
self.cached_tokens = 0
581+
# Context-worker usage for gen-first disagg, delivered via the
582+
# KV-transfer aux buffer (see _maybe_attach_ctx_usage).
583+
self.ctx_usage = None
581584
# Time breakdown metrics for performance analysis
582585
# Contains: step_metrics (list), ctx_gpu_forward_time (float), ctx_gpu_sample_time (float)
583586
self.time_breakdown_metrics = time_breakdown_metrics

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,7 @@ def _end_transfer_and_maybe_terminate(self, request: LlmRequest):
789789
response = request.create_response(False, self.dist.rank)
790790
if response:
791791
response.result.cached_tokens = request.cached_tokens
792+
self._maybe_attach_ctx_usage(request, response)
792793
# Buffer the response instead of enqueueing immediately.
793794
# With ADP, _enqueue_responses does a tp_gather collective.
794795
# Calling it here would deadlock because only the owning DP
@@ -4321,6 +4322,15 @@ def fail_request(message: str) -> bool:
43214322
cum_log_probs[seq_slot, :beam_width].copy_(values)
43224323
return True
43234324

4325+
@staticmethod
4326+
def _maybe_attach_ctx_usage(request: LlmRequest, response):
4327+
"""Surface gen-first ctx usage (delivered via the KV-transfer aux
4328+
buffer in RxSession.unpack_aux) onto the response so the postprocessor
4329+
adopts the context-side prompt/cached token accounting."""
4330+
disagg_params = request.py_disaggregated_params
4331+
if disagg_params is not None and disagg_params.ctx_usage is not None:
4332+
response.result.ctx_usage = disagg_params.ctx_usage
4333+
43244334
def _maybe_prepend_logprobs_and_logits(self, req, beam_width):
43254335
"""Prepend logprobs and generation logits for first_gen_tokens
43264336
if transferred from prefill."""
@@ -4980,6 +4990,7 @@ def _emit_first_token_responses(self, prev_scheduled_requests):
49804990
if response is None:
49814991
continue
49824992
response.result.cached_tokens = request.cached_tokens
4993+
self._maybe_attach_ctx_usage(request, response)
49834994
if logits_snapshot is not None:
49844995
response.result.generation_logits = logits_snapshot
49854996
new_responses.append((request.py_request_id, response))
@@ -5067,6 +5078,7 @@ def _handle_responses(self, emit_first_iter: bool = True):
50675078
if response:
50685079
request_done = request.is_finished
50695080
response.result.cached_tokens = request.cached_tokens
5081+
self._maybe_attach_ctx_usage(request, response)
50705082
response.result.per_pos_drafted = request.py_per_pos_drafted
50715083
response.result.per_pos_accepted = request.py_per_pos_accepted
50725084
new_responses.append((req_id, response))

tensorrt_llm/executor/result.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,17 @@ def _handle_response(self,
498498
self.per_pos_accepted = getattr(response_result, 'per_pos_accepted',
499499
None)
500500
self.avg_decoded_tokens_per_iter = response_result.avg_decoded_tokens_per_iter
501+
# Expose gen-first ctx usage so the postprocessor
502+
# (_ctx_usage_from_outputs) can adopt the context-side accounting.
503+
# ctx_usage only exists on the Python LlmResult wrapper; the raw C++
504+
# bindings.executor.Result (non-disagg / benchmark path) does not
505+
# have it, so fall back to None as with cached_tokens above.
506+
ctx_usage = getattr(response_result, 'ctx_usage', None)
507+
if ctx_usage is not None:
508+
self._disaggregated_params = dataclasses.replace(
509+
self._disaggregated_params or DisaggregatedParams(),
510+
ctx_usage=ctx_usage,
511+
)
501512
if context_phase_params is not None:
502513
existing_disagg_params = self.disaggregated_params
503514
# Use `replace` to preserve things like `mrope_position_ids_handle` and

tests/integration/defs/disaggregated/test_configs/disagg_config_overlap_gen_first.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ context_servers:
1212
tensor_parallel_size: 1
1313
pipeline_parallel_size: 1
1414
kv_cache_config:
15-
enable_block_reuse: False
15+
enable_block_reuse: True
1616
free_gpu_memory_fraction: 0.2
17-
enable_partial_reuse: False
17+
enable_partial_reuse: True
1818

1919
cache_transceiver_config:
2020
backend: DEFAULT
@@ -29,9 +29,9 @@ generation_servers:
2929
max_num_tokens: 4096
3030
max_seq_len: 4096
3131
kv_cache_config:
32-
enable_block_reuse: False
32+
enable_block_reuse: True
3333
free_gpu_memory_fraction: 0.2
34-
enable_partial_reuse: False
34+
enable_partial_reuse: True
3535
cache_transceiver_config:
3636
backend: DEFAULT
3737
transceiver_runtime: PYTHON

tests/integration/defs/disaggregated/test_configs/disagg_config_overlap_gen_first_pp4.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ context_servers:
1212
tensor_parallel_size: 1
1313
pipeline_parallel_size: 4
1414
kv_cache_config:
15-
enable_block_reuse: False
15+
enable_block_reuse: True
1616
free_gpu_memory_fraction: 0.2
17-
enable_partial_reuse: False
17+
enable_partial_reuse: True
1818

1919
cache_transceiver_config:
2020
backend: DEFAULT
@@ -29,9 +29,9 @@ generation_servers:
2929
max_num_tokens: 4096
3030
max_seq_len: 4096
3131
kv_cache_config:
32-
enable_block_reuse: False
32+
enable_block_reuse: True
3333
free_gpu_memory_fraction: 0.2
34-
enable_partial_reuse: False
34+
enable_partial_reuse: True
3535
cache_transceiver_config:
3636
backend: DEFAULT
3737
transceiver_runtime: PYTHON

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,6 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1
128128
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_two_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/6162322)
129129
disaggregated/test_disaggregated.py::test_disaggregated_genbs1[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6162322)
130130
disaggregated/test_disaggregated.py::test_disaggregated_gpt_oss_120b_harmony[gpt_oss/gpt-oss-120b] SKIP (https://nvbugs/6245317)
131-
disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp1-TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6223556)
132-
disaggregated/test_disaggregated.py::test_disaggregated_overlap_gen_first[ctx_pp4-TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6223556)
133131
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_llama_context_capacity[False-False-DeepSeek-V3-Lite-fp8/fp8] SKIP (https://nvbugs/6266302)
134132
disaggregated/test_workers.py::test_workers_conversation_router[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6162322)
135133
disaggregated/test_workers.py::test_workers_kv_cache_aware_router_deepseek_v3_lite_bf16[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/6162322)

0 commit comments

Comments
 (0)