Skip to content

Commit 3081468

Browse files
committed
[None][fix] Align GPTOSS router tokenization and disagg draft scheduling
Use tool-aware chat template and Harmony tokenization for KV-cache-aware router hashes without mutating forwarded OpenAI requests. Sync disaggregated generation draft tokens from context-phase params before scheduling so batch capacity accounting sees transferred draft tokens. Add unit coverage for router/server tokenization parity and disagg draft-token scheduler accounting. Signed-off-by: Simeng Liu <simengl@nvidia.com>
1 parent 8412a17 commit 3081468

4 files changed

Lines changed: 507 additions & 14 deletions

File tree

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2614,6 +2614,31 @@ def _prefetch_for_context_requests(self) -> None:
26142614
if candidates:
26152615
self.kv_cache_manager.prefetch_for_context_tokens(candidates)
26162616

2617+
@staticmethod
def _sync_disagg_generation_trans_complete_draft_tokens(
        requests: Iterable[LlmRequest]) -> None:
    """Copy context-phase draft tokens onto transfer-complete requests.

    A request is updated only when its
    ``is_disagg_generation_transmission_complete`` attribute is truthy
    (a missing attribute counts as ``False``) and its
    ``context_phase_params`` is not ``None``. For such a request:

    * ``py_draft_tokens`` becomes a new list built from
      ``context_phase_params.draft_tokens`` (an empty list when that
      value is ``None``),
    * ``draft_tokens`` is assigned from ``py_draft_tokens``,
    * ``py_draft_pages_allocated`` is set to ``len(py_draft_tokens)``.

    Every other request is left untouched.

    Args:
        requests: The requests to inspect and, where applicable, update
            in place.
    """
    for req in requests:
        transfer_done = getattr(
            req, "is_disagg_generation_transmission_complete", False)
        # Only look at the context-phase params once the flag is set.
        params = req.context_phase_params if transfer_done else None
        if params is None:
            continue

        source_tokens = params.draft_tokens
        if source_tokens is None:
            req.py_draft_tokens = []
        else:
            # Copy so the request does not share the params' container.
            req.py_draft_tokens = list(source_tokens)
        req.draft_tokens = req.py_draft_tokens
        req.py_draft_pages_allocated = len(req.py_draft_tokens)
2634+
2635+
@staticmethod
def _get_generation_num_draft_tokens(request: LlmRequest) -> int:
    """Return the draft-token count to account for on a generation request.

    When the request has no ``py_draft_tokens`` attribute, or it is
    ``None``, the request's own ``num_draft_tokens`` is returned.
    Otherwise the result is the larger of ``len(py_draft_tokens)`` and
    ``num_draft_tokens``.

    Args:
        request: The generation request to inspect.

    Returns:
        The number of draft tokens to count for this request.
    """
    python_side_tokens = getattr(request, "py_draft_tokens", None)
    reported = request.num_draft_tokens
    if python_side_tokens is not None:
        return max(len(python_side_tokens), reported)
    return reported
2641+
26172642
def _prepare_and_schedule_batch(self):
26182643
new_requests = self._fetch_and_activate_new_requests()
26192644
if self.should_stop_processing:
@@ -2623,6 +2648,8 @@ def _prepare_and_schedule_batch(self):
26232648
self._check_disagg_ctx_schedulable_status(new_requests)
26242649
self._check_disagg_gen_transfer_status()
26252650
self._check_kv_transfer_timeout()
2651+
self._sync_disagg_generation_trans_complete_draft_tokens(
2652+
self.active_requests)
26262653

26272654
iter_stats = None
26282655
if self.enable_iter_perf_stats:
@@ -4030,8 +4057,9 @@ def _compute_scheduled_tokens(context_requests, generation_requests):
40304057
else:
40314058
compute = max(1, remaining - reusable_in_chunk)
40324059
num_scheduled_ctx_tokens += compute
4033-
num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
4034-
for gen_req in generation_requests)
4060+
num_scheduled_gen_tokens = sum(
4061+
1 + PyExecutor._get_generation_num_draft_tokens(gen_req)
4062+
for gen_req in generation_requests)
40354063
return num_scheduled_ctx_tokens + num_scheduled_gen_tokens
40364064

40374065
def _waiting_requests(self, context_requests: list[LlmRequest],
@@ -4455,7 +4483,10 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
44554483
ctx_draft_tokens = [
44564484
0
44574485
] * self.model_engine.max_total_draft_tokens
4458-
req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens
4486+
req.py_draft_tokens = [] if ctx_draft_tokens is None else list(
4487+
ctx_draft_tokens)
4488+
req.draft_tokens = req.py_draft_tokens
4489+
req.py_draft_pages_allocated = len(req.py_draft_tokens)
44594490
beam_width = req.py_beam_width
44604491
if not self._update_sampler_state_for_disagg_gen_request(
44614492
req, beam_width, first_gen_tokens):

tensorrt_llm/serve/router.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import json
1617
import os
1718
import time
1819
from abc import ABC, abstractmethod
@@ -676,16 +677,20 @@ class BlockHashMixin:
676677

677678
def _init_block_hashing(self,
678679
tokens_per_block: int = 32,
679-
custom_tokenizer: Optional[str] = None):
680+
custom_tokenizer: Optional[str] = None,
681+
use_harmony: Optional[bool] = None) -> None:
680682
env_tokens_per_block = os.environ.get(
681683
"TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK")
682684
if env_tokens_per_block is not None:
683685
tokens_per_block = int(env_tokens_per_block)
684686
self._tokens_per_block = tokens_per_block
685687
self._tokenizers: dict = {}
688+
self._model_types: dict[str, Optional[str]] = {}
686689
self._custom_tokenizer = custom_tokenizer
690+
self._use_harmony = use_harmony
687691
logger.info(f"BlockHashMixin: tokens_per_block={self._tokens_per_block}"
688-
f", custom_tokenizer={self._custom_tokenizer}")
692+
f", custom_tokenizer={self._custom_tokenizer}"
693+
f", use_harmony={self._use_harmony}")
689694

690695
def _get_tokenizer(self, model: str):
691696
if model not in self._tokenizers:
@@ -705,12 +710,69 @@ def _get_tokenizer(self, model: str):
705710
model, trust_remote_code=True).tokenizer
706711
return self._tokenizers[model]
707712

713+
def _get_model_type(self, model: str) -> Optional[str]:
    """Return the model type for ``model``, resolving and caching it once.

    Resolution order:

    1. If the lower-cased name (with ``_`` replaced by ``-``) contains
       ``"gpt-oss"`` or ``"gptoss"``, the type is ``"gpt_oss"``.
    2. Otherwise, if ``<model>/config.json`` exists as a file, its
       top-level ``"model_type"`` entry is used when it is a string.
    3. Otherwise the type is ``None``.

    The lookup is best-effort: any failure to read or parse the config
    file is logged at debug level and results in ``None``. This includes
    files that are not valid UTF-8, which raise ``UnicodeDecodeError``
    rather than ``json.JSONDecodeError``.

    Args:
        model: Model name or local model directory path.

    Returns:
        The detected model type, or ``None`` when it cannot be determined.
    """
    if model in self._model_types:
        return self._model_types[model]

    model_type: Optional[str] = None
    normalized_model = model.lower().replace("_", "-")
    if "gpt-oss" in normalized_model or "gptoss" in normalized_model:
        model_type = "gpt_oss"
    else:
        config_path = os.path.join(model, "config.json")
        if os.path.isfile(config_path):
            try:
                with open(config_path, encoding="utf-8") as config_file:
                    config = json.load(config_file)
            except (OSError, ValueError) as e:
                # ValueError is the common base of json.JSONDecodeError
                # and UnicodeDecodeError, so malformed JSON and
                # mis-encoded files both fall back to "unknown".
                logger.debug(
                    "BlockHashMixin: failed to read model config for "
                    f"{model}: {e}")
            else:
                if isinstance(config, dict):
                    raw_model_type = config.get("model_type")
                    if isinstance(raw_model_type, str):
                        model_type = raw_model_type

    # Cache negative results too so the filesystem is probed only once.
    self._model_types[model] = model_type
    return model_type
735+
736+
def _uses_harmony_tokenization(self,
                               request: ChatCompletionRequest) -> bool:
    """Decide whether ``request`` should be tokenized via Harmony.

    An explicit ``self._use_harmony`` setting (anything other than
    ``None``) is returned as-is. With no explicit setting, the answer is
    whether ``self._get_model_type(request.model)`` equals ``"gpt_oss"``.

    Args:
        request: The chat completion request being routed.

    Returns:
        ``True`` when Harmony tokenization should be used.
    """
    explicit_choice = self._use_harmony
    if explicit_choice is None:
        # No override configured: fall back to model-type detection.
        return self._get_model_type(request.model) == "gpt_oss"
    return explicit_choice
741+
742+
@staticmethod
def _tool_dicts(
        request: ChatCompletionRequest
) -> Optional[list[dict[str, object]]]:
    """Convert the request's tools into the result of ``model_dump()``.

    Args:
        request: The chat completion request whose ``tools`` are read.

    Returns:
        ``None`` when ``request.tools`` is ``None``; otherwise a list with
        ``tool.model_dump()`` for each tool, in order (an empty ``tools``
        sequence gives an empty list).
    """
    declared_tools = request.tools
    if declared_tools is not None:
        return [entry.model_dump() for entry in declared_tools]
    return None
749+
750+
def _tokenize_harmony_chat(
        self, request: ChatCompletionRequest) -> list[list[int]]:
    """Tokenize a chat request through the Harmony adapter.

    The request's messages, its tools (converted with
    ``self._tool_dicts`` when ``request.tools`` is non-empty, else
    ``None``), the reasoning effort passed through
    ``maybe_transform_reasoning_effort`` and the tool choice are handed to
    the adapter's ``openai_to_harmony_tokens``.

    Args:
        request: The chat completion request to tokenize.

    Returns:
        A one-element list wrapping the adapter's result.
    """
    # Imported at call time, as in the rest of this method's history.
    from tensorrt_llm.serve import harmony_adapter

    tool_specs = None
    if request.tools:
        tool_specs = self._tool_dicts(request)

    adapter = harmony_adapter.get_harmony_adapter()
    effort = harmony_adapter.maybe_transform_reasoning_effort(
        request.reasoning_effort)
    harmony_tokens = adapter.openai_to_harmony_tokens(
        request.messages,
        tool_specs,
        reasoning_effort=effort,
        tool_choice=request.tool_choice,
    )
    return [harmony_tokens]
763+
708764
def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
709765
# Handle ChatCompletionRequest (has messages, not prompt)
710766
if isinstance(request, ChatCompletionRequest):
711767
if request.prompt_token_ids is not None:
712768
return [request.prompt_token_ids]
769+
if self._uses_harmony_tokenization(request):
770+
return self._tokenize_harmony_chat(request)
713771
tokenizer = self._get_tokenizer(request.model)
772+
# Forward tool schemas and chat-template flags so router hashes use
773+
# the same rendered prompt as the worker-side tokenizer.
774+
chat_template_kwargs = dict(request.chat_template_kwargs or {})
775+
chat_template_kwargs["tools"] = self._tool_dicts(request)
714776
result = tokenizer.apply_chat_template(
715777
[
716778
msg if isinstance(msg, dict) else dict(msg)
@@ -719,14 +781,13 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
719781
add_generation_prompt=request.add_generation_prompt,
720782
tokenize=True,
721783
return_dict=False,
784+
**chat_template_kwargs,
722785
)
723786
# Some custom tokenizers (e.g. DeepseekV32Tokenizer) return a
724787
# string from apply_chat_template even with tokenize=True.
725788
# Encode to token IDs if needed.
726789
if isinstance(result, str):
727790
result = tokenizer.encode(result, add_special_tokens=False)
728-
# Set prompt_token_ids so the worker server skips re-tokenization
729-
request.prompt_token_ids = result
730791
return [result]
731792

732793
# Handle CompletionRequest (has prompt)
@@ -742,10 +803,6 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
742803

743804
tokenizer = self._get_tokenizer(request.model)
744805
token_lists = [tokenizer(prompt)["input_ids"] for prompt in prompts]
745-
# Replace string prompts with token IDs so the worker server
746-
# skips re-tokenization
747-
request.prompt = (token_lists
748-
if len(token_lists) > 1 else token_lists[0])
749806
return token_lists
750807

751808
def _compute_block_hashes(self,
@@ -799,10 +856,12 @@ def __init__(self,
799856
max_batch_size: int = 64,
800857
tokens_per_block: int = 32,
801858
custom_tokenizer: Optional[str] = None,
859+
use_harmony: Optional[bool] = None,
802860
**kwargs):
803861
super().__init__(server_role, servers, metadata_server_cfg,
804862
metadata_server, **kwargs)
805-
self._init_block_hashing(tokens_per_block, custom_tokenizer)
863+
self._init_block_hashing(tokens_per_block, custom_tokenizer,
864+
use_harmony)
806865
self._init_load_balancing(servers, use_tokens)
807866
# TODO: use max_num_tokens? per server?
808867
self._max_batch_size = max_batch_size

tests/unittest/_torch/executor/test_py_executor.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,23 @@ def _make_ctx_request(
273273
return req
274274

275275

276-
def _make_gen_request(num_draft_tokens=0):
276+
def _make_gen_request(num_draft_tokens: int = 0) -> Mock:
277277
"""Helper to create a mock generation request."""
278278
req = Mock()
279279
req.num_draft_tokens = num_draft_tokens
280+
req.py_draft_tokens = None
281+
req.is_disagg_generation_transmission_complete = False
282+
return req
283+
284+
285+
def _make_disagg_trans_complete_request(draft_tokens: list[int] | None) -> Mock:
    """Helper to create a mock transfer-complete disagg generation request.

    The mock has ``is_disagg_generation_transmission_complete`` set to
    ``True``, carries ``draft_tokens`` on its ``context_phase_params``, and
    starts with empty/zero request-side draft state.
    """
    ctx_params = Mock(draft_tokens=draft_tokens)
    return Mock(
        is_disagg_generation_transmission_complete=True,
        context_phase_params=ctx_params,
        py_draft_tokens=[],
        draft_tokens=[],
        py_draft_pages_allocated=0,
        num_draft_tokens=0,
    )
281294

282295

@@ -362,6 +375,37 @@ def test_generation_tokens(self):
362375
gen = [_make_gen_request(3), _make_gen_request(0)]
363376
assert PyExecutor._compute_scheduled_tokens([], gen) == (1 + 3) + (1 + 0)
364377

378+
def test_disagg_trans_complete_draft_tokens_are_scheduler_visible(self) -> None:
    """Synced draft tokens must be counted in the scheduled token total."""
    num_regular = 127
    tokens_per_request = 1 + 3  # one generated token plus three drafts
    transferred_drafts = [11, 12, 13]

    trans_complete = _make_disagg_trans_complete_request(transferred_drafts)
    gen = [_make_gen_request(3) for _ in range(num_regular)]
    gen.append(trans_complete)

    # Before the sync the transfer-complete request counts as one token.
    before = PyExecutor._compute_scheduled_tokens([], gen)
    assert before == num_regular * tokens_per_request + 1

    PyExecutor._sync_disagg_generation_trans_complete_draft_tokens(gen)

    assert trans_complete.py_draft_tokens == [11, 12, 13]
    assert trans_complete.draft_tokens == [11, 12, 13]
    assert trans_complete.py_draft_pages_allocated == 3

    # After the sync it counts like every other three-draft request.
    after = PyExecutor._compute_scheduled_tokens([], gen)
    assert after == (num_regular + 1) * tokens_per_request
391+
392+
def test_disagg_trans_complete_missing_draft_tokens_are_scheduler_visible(self) -> None:
    """A ``None`` context-phase draft list syncs to empty draft state."""
    trans_complete = _make_disagg_trans_complete_request(None)
    batch = [trans_complete]

    PyExecutor._sync_disagg_generation_trans_complete_draft_tokens(batch)

    assert trans_complete.py_draft_tokens == []
    assert trans_complete.draft_tokens == []
    assert trans_complete.py_draft_pages_allocated == 0
    # With no drafts, only the single generated token is scheduled.
    assert PyExecutor._compute_scheduled_tokens([], batch) == 1
400+
401+
def test_sync_disagg_draft_tokens_ignores_regular_generation_requests(self) -> None:
    """Requests that are not transfer-complete are left untouched."""
    gen = _make_gen_request(3)
    batch = [gen]

    PyExecutor._sync_disagg_generation_trans_complete_draft_tokens(batch)

    assert gen.py_draft_tokens is None
    # Accounting still uses num_draft_tokens: 1 generated + 3 drafts.
    assert PyExecutor._compute_scheduled_tokens([], batch) == 4
408+
365409
def test_mixed_context_and_generation(self):
366410
"""Combined context (with chunk-shift) and generation tokens."""
367411
# Non-last chunk: compute = 25

0 commit comments

Comments
 (0)