Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2596,6 +2596,31 @@ def _commit_kv_cache_stats(self,
if self._is_kv_manager_v2:
self.kv_cache_manager.commit_scheduled_kv_cache_stats(
scheduled_batch)

@staticmethod
def _sync_disagg_generation_trans_complete_draft_tokens(
        requests: Iterable[LlmRequest]) -> None:
    """Copy context-phase draft tokens onto transfer-complete requests.

    For every request whose ``is_disagg_generation_transmission_complete``
    flag is truthy and which carries ``context_phase_params``, the draft
    tokens recorded there are mirrored into ``py_draft_tokens`` and
    ``draft_tokens``, and ``py_draft_pages_allocated`` is set to their
    count. Requests without the flag, or without context-phase params, are
    left untouched.

    Args:
        requests: Requests to inspect; matching ones are updated in place.
    """
    for request in requests:
        transfer_done = getattr(
            request, "is_disagg_generation_transmission_complete", False)
        # Only look at context_phase_params once the flag is set, so
        # requests lacking that attribute are never dereferenced.
        params = request.context_phase_params if transfer_done else None
        if params is None:
            continue

        ctx_drafts = params.draft_tokens
        # Always store a fresh list so the request does not share the
        # container held by context_phase_params; None becomes "no drafts".
        request.py_draft_tokens = (list(ctx_drafts)
                                   if ctx_drafts is not None else [])
        request.draft_tokens = request.py_draft_tokens
        request.py_draft_pages_allocated = len(request.py_draft_tokens)
Comment on lines +2600 to +2616

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Use the same draft-token resolver before scheduling and transmission-complete setup.

Line 2630 treats None/empty ctx drafts as zero scheduled draft tokens, but Line 4478 later turns that same no-draft case into max_total_draft_tokens dummy drafts when model_engine.enable_spec_decode is true. That can undercount batch capacity before the request prepares more decode tokens. Move this sync after the current spec-decode decision and share the same resolver with _prepare_disagg_gen_transmission_complete.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` around lines 2617 - 2633, The
draft-token sync in _sync_disagg_generation_trans_complete_draft_tokens is using
a different draft count rule than _prepare_disagg_gen_transmission_complete,
which can undercount capacity when spec decode is enabled. Update the
transmission-complete sync to use the same draft-token resolver/logic as
_prepare_disagg_gen_transmission_complete, and move this sync so it runs after
the spec-decode decision is applied. Keep the behavior aligned for empty or None
context drafts and ensure py_draft_tokens, draft_tokens, and
py_draft_pages_allocated are derived from the shared resolution path.


@staticmethod
def _get_generation_num_draft_tokens(request: LlmRequest) -> int:
    """Return the draft-token count to budget for a generation request.

    When the request has a non-``None`` ``py_draft_tokens`` list, the
    result is the larger of its length and ``request.num_draft_tokens``;
    otherwise it is ``request.num_draft_tokens`` alone.

    Args:
        request: Generation request being accounted for.

    Returns:
        Number of draft tokens to count for ``request``.
    """
    staged = getattr(request, "py_draft_tokens", None)
    reported = request.num_draft_tokens
    return reported if staged is None else max(len(staged), reported)

def _prepare_and_schedule_batch(self):
new_requests = self._fetch_and_activate_new_requests()
Expand All @@ -2606,6 +2631,8 @@ def _prepare_and_schedule_batch(self):
self._check_disagg_ctx_schedulable_status(new_requests)
self._check_disagg_gen_transfer_status()
self._check_kv_transfer_timeout()
self._sync_disagg_generation_trans_complete_draft_tokens(
self.active_requests)
Comment on lines +2634 to +2635

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

Mirror this sync in the PP scheduling path.

Line 2651 only covers _prepare_and_schedule_batch; _executor_loop_pp bypasses that helper and calls _pp_schedule_and_propagate() / _schedule() directly, so pp_size > 1 disagg generation requests can still schedule with stale draft-token fields.

Suggested placement in `_executor_loop_pp`
 if self.kv_cache_transceiver:
     self._check_disagg_ctx_schedulable_status(new_requests)
     self._check_disagg_gen_transfer_status()
+    self._sync_disagg_generation_trans_complete_draft_tokens(
+        self.active_requests)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` around lines 2651 - 2652, The
PP executor path is missing the same disagg generation transition sync that
`_prepare_and_schedule_batch` already performs, so `pp_size > 1` can schedule
requests with stale draft-token state. Add the
`_sync_disagg_generation_trans_complete_draft_tokens` call in
`_executor_loop_pp` before it invokes `_pp_schedule_and_propagate()` or
`_schedule()`, mirroring the existing behavior in `_prepare_and_schedule_batch`
and using `self.active_requests` there as the source.


iter_stats = None
if self.enable_iter_perf_stats:
Expand Down Expand Up @@ -4016,8 +4043,9 @@ def _compute_scheduled_tokens(context_requests, generation_requests):
else:
compute = max(1, remaining - reusable_in_chunk)
num_scheduled_ctx_tokens += compute
num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
for gen_req in generation_requests)
num_scheduled_gen_tokens = sum(
1 + PyExecutor._get_generation_num_draft_tokens(gen_req)
for gen_req in generation_requests)
return num_scheduled_ctx_tokens + num_scheduled_gen_tokens

def _waiting_requests(self, context_requests: list[LlmRequest],
Expand Down Expand Up @@ -4441,7 +4469,10 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
ctx_draft_tokens = [
0
] * self.model_engine.max_total_draft_tokens
req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
req.py_draft_tokens = [] if ctx_draft_tokens is None else list(
ctx_draft_tokens)
req.draft_tokens = req.py_draft_tokens
req.py_draft_pages_allocated = len(req.py_draft_tokens)
beam_width = req.py_beam_width
if not self._update_sampler_state_for_disagg_gen_request(
req, beam_width, first_gen_tokens):
Expand Down
77 changes: 68 additions & 9 deletions tensorrt_llm/serve/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import json
import os
import time
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -676,16 +677,20 @@ class BlockHashMixin:

def _init_block_hashing(self,
                        tokens_per_block: int = 32,
                        custom_tokenizer: Optional[str] = None,
                        use_harmony: Optional[bool] = None) -> None:
    """Initialise the block-hashing state kept on this instance.

    Args:
        tokens_per_block: Tokens per hashed block. Overridden by the
            ``TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK``
            environment variable when that variable is set.
        custom_tokenizer: Stored as ``_custom_tokenizer``.
        use_harmony: Stored as ``_use_harmony``; ``None`` leaves the
            Harmony decision to per-model detection.

    Raises:
        ValueError: If the environment override is not a valid integer.
    """
    override = os.environ.get(
        "TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK")
    # The environment variable, when present, wins over the argument.
    self._tokens_per_block = (tokens_per_block
                              if override is None else int(override))
    # Per-model caches, filled lazily by the lookup helpers.
    self._tokenizers: dict = {}
    self._model_types: dict[str, Optional[str]] = {}
    self._custom_tokenizer = custom_tokenizer
    self._use_harmony = use_harmony
    logger.info(f"BlockHashMixin: tokens_per_block={self._tokens_per_block}"
                f", custom_tokenizer={self._custom_tokenizer}"
                f", use_harmony={self._use_harmony}")

def _get_tokenizer(self, model: str):
if model not in self._tokenizers:
Expand All @@ -705,12 +710,69 @@ def _get_tokenizer(self, model: str):
model, trust_remote_code=True).tokenizer
return self._tokenizers[model]

def _get_model_type(self, model: str) -> Optional[str]:
    """Return the ``model_type`` for *model*, resolving and caching it once.

    Names containing ``gpt-oss`` or ``gptoss`` (case-insensitive, with
    underscores treated as hyphens) resolve to ``"gpt_oss"`` without
    touching the filesystem. Otherwise *model* is treated as a directory
    and ``model_type`` is read from its ``config.json`` when that file
    exists and holds a JSON object with a string ``model_type``.

    The lookup is best-effort: an unreadable, undecodable or malformed
    config is logged at debug level and yields ``None``. Every outcome,
    including ``None``, is cached in ``self._model_types``.

    Args:
        model: Model name or local model directory.

    Returns:
        The detected model type, or ``None`` when it cannot be determined.
    """
    if model in self._model_types:
        return self._model_types[model]

    model_type = None
    normalized_model = model.lower().replace("_", "-")
    if "gpt-oss" in normalized_model or "gptoss" in normalized_model:
        model_type = "gpt_oss"
    else:
        config_path = os.path.join(model, "config.json")
        if os.path.isfile(config_path):
            try:
                with open(config_path, encoding="utf-8") as config_file:
                    config = json.load(config_file)
            # ValueError covers json.JSONDecodeError and also
            # UnicodeDecodeError (file is not valid UTF-8); both are
            # ValueError subclasses.
            except (OSError, ValueError) as e:
                logger.debug(
                    "BlockHashMixin: failed to read model config for "
                    f"{model}: {e}")
            else:
                if isinstance(config, dict):
                    raw_model_type = config.get("model_type")
                    if isinstance(raw_model_type, str):
                        model_type = raw_model_type

    self._model_types[model] = model_type
    return model_type

def _uses_harmony_tokenization(self,
                               request: ChatCompletionRequest) -> bool:
    """Decide whether *request* should be tokenized with Harmony.

    Resolution order:

    1. ``DISABLE_HARMONY_ADAPTER=1`` in the environment disables Harmony
       outright.
    2. An explicit ``use_harmony`` setting (``self._use_harmony`` not
       ``None``) is returned as-is.
    3. Otherwise Harmony is used exactly when the request's model type
       resolves to ``"gpt_oss"``.

    Args:
        request: Chat completion request whose ``model`` is inspected.

    Returns:
        ``True`` when Harmony tokenization should be used.
    """
    # NOTE(review): per the review discussion, the serving side turns
    # Harmony off when this variable is "1"; mirroring that keeps router
    # block hashes aligned with worker tokenization — confirm the router
    # process sees the same environment as the workers.
    if os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1":
        return False
    if self._use_harmony is not None:
        return self._use_harmony
    return self._get_model_type(request.model) == "gpt_oss"
Comment on lines +736 to +740

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win

Honor the server-side Harmony disable gate.

OpenAIServer resolves DISABLE_HARMONY_ADAPTER=1 to use_harmony=False, but this default router inference still enables Harmony for GPT-OSS models. That makes router KV hashes use Harmony tokens while the worker renders the non-Harmony chat template. Consider applying the same env gate here or always passing the server’s resolved use_harmony value into KvCacheAwareRouter.

Suggested alignment
 def _uses_harmony_tokenization(self,
                                request: ChatCompletionRequest) -> bool:
+    if os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1":
+        return False
     if self._use_harmony is not None:
         return self._use_harmony
     return self._get_model_type(request.model) == "gpt_oss"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _uses_harmony_tokenization(self,
request: ChatCompletionRequest) -> bool:
if self._use_harmony is not None:
return self._use_harmony
return self._get_model_type(request.model) == "gpt_oss"
def _uses_harmony_tokenization(self,
request: ChatCompletionRequest) -> bool:
if os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1":
return False
if self._use_harmony is not None:
return self._use_harmony
return self._get_model_type(request.model) == "gpt_oss"
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/serve/router.py` around lines 736 - 740, The Harmony
tokenization decision in `_uses_harmony_tokenization` is still defaulting
GPT-OSS models to Harmony even when the server has disabled it. Update
`KvCacheAwareRouter` to use the same resolved `use_harmony` value that
`OpenAIServer` derives from `DISABLE_HARMONY_ADAPTER`, or otherwise apply that
env gate here, so router KV hashing and worker chat template rendering stay
aligned.


@staticmethod
def _tool_dicts(
    request: ChatCompletionRequest
) -> Optional[list[dict[str, object]]]:
    """Return the request's tools dumped via ``model_dump()``.

    Args:
        request: Chat completion request whose ``tools`` are converted.

    Returns:
        ``None`` when ``request.tools`` is ``None``; otherwise a list with
        one ``model_dump()`` result per tool (empty for an empty tool list).
    """
    tools = request.tools
    return None if tools is None else [tool.model_dump() for tool in tools]

def _tokenize_harmony_chat(
        self, request: ChatCompletionRequest) -> list[list[int]]:
    """Tokenize a chat request through the Harmony adapter.

    Args:
        request: Chat completion request supplying ``messages``, ``tools``,
            ``reasoning_effort`` and ``tool_choice``.

    Returns:
        A single-element list wrapping the adapter's token result.
    """
    # Function-scope import; presumably keeps the Harmony adapter from
    # loading unless a request actually needs it — TODO confirm.
    from tensorrt_llm.serve import harmony_adapter

    # A missing or empty tool list is passed to the adapter as None.
    tool_specs = self._tool_dicts(request) if request.tools else None
    adapter = harmony_adapter.get_harmony_adapter()
    effort = harmony_adapter.maybe_transform_reasoning_effort(
        request.reasoning_effort)
    token_ids = adapter.openai_to_harmony_tokens(
        request.messages,
        tool_specs,
        reasoning_effort=effort,
        tool_choice=request.tool_choice,
    )
    return [token_ids]

def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
# Handle ChatCompletionRequest (has messages, not prompt)
if isinstance(request, ChatCompletionRequest):
if request.prompt_token_ids is not None:
return [request.prompt_token_ids]
if self._uses_harmony_tokenization(request):
return self._tokenize_harmony_chat(request)
tokenizer = self._get_tokenizer(request.model)
# Forward tool schemas and chat-template flags so router hashes use
# the same rendered prompt as the worker-side tokenizer.
chat_template_kwargs = dict(request.chat_template_kwargs or {})
chat_template_kwargs["tools"] = self._tool_dicts(request)
result = tokenizer.apply_chat_template(
[
msg if isinstance(msg, dict) else dict(msg)
Expand All @@ -719,14 +781,13 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
add_generation_prompt=request.add_generation_prompt,
tokenize=True,
return_dict=False,
**chat_template_kwargs,
)
# Some custom tokenizers (e.g. DeepseekV32Tokenizer) return a
# string from apply_chat_template even with tokenize=True.
# Encode to token IDs if needed.
if isinstance(result, str):
result = tokenizer.encode(result, add_special_tokens=False)
# Set prompt_token_ids so the worker server skips re-tokenization
request.prompt_token_ids = result
return [result]

# Handle CompletionRequest (has prompt)
Expand All @@ -742,10 +803,6 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:

tokenizer = self._get_tokenizer(request.model)
token_lists = [tokenizer(prompt)["input_ids"] for prompt in prompts]
# Replace string prompts with token IDs so the worker server
# skips re-tokenization
request.prompt = (token_lists
if len(token_lists) > 1 else token_lists[0])
return token_lists

def _compute_block_hashes(self,
Expand Down Expand Up @@ -799,10 +856,12 @@ def __init__(self,
max_batch_size: int = 64,
tokens_per_block: int = 32,
custom_tokenizer: Optional[str] = None,
use_harmony: Optional[bool] = None,
**kwargs):
super().__init__(server_role, servers, metadata_server_cfg,
metadata_server, **kwargs)
self._init_block_hashing(tokens_per_block, custom_tokenizer)
self._init_block_hashing(tokens_per_block, custom_tokenizer,
use_harmony)
self._init_load_balancing(servers, use_tokens)
# TODO: use max_num_tokens? per server?
self._max_batch_size = max_batch_size
Expand Down
46 changes: 45 additions & 1 deletion tests/unittest/_torch/executor/test_py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,23 @@ def _make_ctx_request(
return req


def _make_gen_request(num_draft_tokens=0):
def _make_gen_request(num_draft_tokens: int = 0) -> Mock:
    """Build a mock generation request that is not a disagg transfer.

    The mock reports ``num_draft_tokens`` as given, has no staged
    ``py_draft_tokens`` and has its transmission-complete flag cleared.
    """
    return Mock(
        num_draft_tokens=num_draft_tokens,
        py_draft_tokens=None,
        is_disagg_generation_transmission_complete=False,
    )


def _make_disagg_trans_complete_request(draft_tokens: list[int] | None) -> Mock:
    """Build a mock request whose disagg generation transfer has completed.

    ``draft_tokens`` is exposed only through ``context_phase_params``; the
    request's own draft-token fields start empty (lists empty, counters 0).
    """
    return Mock(
        is_disagg_generation_transmission_complete=True,
        context_phase_params=Mock(draft_tokens=draft_tokens),
        py_draft_tokens=[],
        draft_tokens=[],
        py_draft_pages_allocated=0,
        num_draft_tokens=0,
    )


Expand Down Expand Up @@ -362,6 +375,37 @@ def test_generation_tokens(self):
gen = [_make_gen_request(3), _make_gen_request(0)]
assert PyExecutor._compute_scheduled_tokens([], gen) == (1 + 3) + (1 + 0)

def test_disagg_trans_complete_draft_tokens_are_scheduler_visible(self) -> None:
    """Synced context-phase drafts must count toward scheduled tokens."""
    synced_req = _make_disagg_trans_complete_request([11, 12, 13])
    batch = [_make_gen_request(3) for _ in range(127)] + [synced_req]

    # Before the sync the transfer-complete request exposes no drafts,
    # so it contributes a single token.
    assert PyExecutor._compute_scheduled_tokens([], batch) == 127 * 4 + 1

    PyExecutor._sync_disagg_generation_trans_complete_draft_tokens(batch)

    expected_drafts = [11, 12, 13]
    assert synced_req.py_draft_tokens == expected_drafts
    assert synced_req.draft_tokens == expected_drafts
    assert synced_req.py_draft_pages_allocated == 3
    # After the sync it is budgeted like the other 3-draft requests.
    assert PyExecutor._compute_scheduled_tokens([], batch) == 128 * 4

def test_disagg_trans_complete_missing_draft_tokens_are_scheduler_visible(self) -> None:
    """A ``None`` context-phase draft list syncs to empty draft fields."""
    req = _make_disagg_trans_complete_request(None)

    PyExecutor._sync_disagg_generation_trans_complete_draft_tokens([req])

    assert req.py_draft_tokens == []
    assert req.draft_tokens == []
    assert req.py_draft_pages_allocated == 0
    # With no drafts the request is scheduled as one token.
    assert PyExecutor._compute_scheduled_tokens([], [req]) == 1

def test_sync_disagg_draft_tokens_ignores_regular_generation_requests(self) -> None:
    """Requests without the transfer-complete flag are left untouched."""
    regular = _make_gen_request(3)

    PyExecutor._sync_disagg_generation_trans_complete_draft_tokens([regular])

    # py_draft_tokens stays unset, so the count comes from num_draft_tokens.
    assert regular.py_draft_tokens is None
    assert PyExecutor._compute_scheduled_tokens([], [regular]) == 4

def test_mixed_context_and_generation(self):
"""Combined context (with chunk-shift) and generation tokens."""
# Non-last chunk: compute = 25
Expand Down
Loading
Loading