-
Notifications
You must be signed in to change notification settings - Fork 2.5k
[None][fix] Align GPTOSS router tokenization and disagg draft scheduling #15605
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2596,6 +2596,31 @@ def _commit_kv_cache_stats(self, | |
| if self._is_kv_manager_v2: | ||
| self.kv_cache_manager.commit_scheduled_kv_cache_stats( | ||
| scheduled_batch) | ||
|
|
||
| @staticmethod | ||
| def _sync_disagg_generation_trans_complete_draft_tokens( | ||
| requests: Iterable[LlmRequest]) -> None: | ||
| for request in requests: | ||
| if not getattr(request, | ||
| "is_disagg_generation_transmission_complete", False): | ||
| continue | ||
|
|
||
| context_phase_params = request.context_phase_params | ||
| if context_phase_params is None: | ||
| continue | ||
|
|
||
| draft_tokens = context_phase_params.draft_tokens | ||
| request.py_draft_tokens = [] if draft_tokens is None else list( | ||
| draft_tokens) | ||
| request.draft_tokens = request.py_draft_tokens | ||
| request.py_draft_pages_allocated = len(request.py_draft_tokens) | ||
|
|
||
| @staticmethod | ||
| def _get_generation_num_draft_tokens(request: LlmRequest) -> int: | ||
| py_draft_tokens = getattr(request, "py_draft_tokens", None) | ||
| if py_draft_tokens is None: | ||
| return request.num_draft_tokens | ||
| return max(len(py_draft_tokens), request.num_draft_tokens) | ||
|
|
||
| def _prepare_and_schedule_batch(self): | ||
| new_requests = self._fetch_and_activate_new_requests() | ||
|
|
@@ -2606,6 +2631,8 @@ def _prepare_and_schedule_batch(self): | |
| self._check_disagg_ctx_schedulable_status(new_requests) | ||
| self._check_disagg_gen_transfer_status() | ||
| self._check_kv_transfer_timeout() | ||
| self._sync_disagg_generation_trans_complete_draft_tokens( | ||
| self.active_requests) | ||
|
Comment on lines
+2634
to
+2635
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win Mirror this sync in the PP scheduling path. Line 2651 only covers Suggested placement in `_executor_loop_pp` if self.kv_cache_transceiver:
self._check_disagg_ctx_schedulable_status(new_requests)
self._check_disagg_gen_transfer_status()
+ self._sync_disagg_generation_trans_complete_draft_tokens(
+ self.active_requests)🤖 Prompt for AI Agents |
||
|
|
||
| iter_stats = None | ||
| if self.enable_iter_perf_stats: | ||
|
|
@@ -4016,8 +4043,9 @@ def _compute_scheduled_tokens(context_requests, generation_requests): | |
| else: | ||
| compute = max(1, remaining - reusable_in_chunk) | ||
| num_scheduled_ctx_tokens += compute | ||
| num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens | ||
| for gen_req in generation_requests) | ||
| num_scheduled_gen_tokens = sum( | ||
| 1 + PyExecutor._get_generation_num_draft_tokens(gen_req) | ||
| for gen_req in generation_requests) | ||
| return num_scheduled_ctx_tokens + num_scheduled_gen_tokens | ||
|
|
||
| def _waiting_requests(self, context_requests: list[LlmRequest], | ||
|
|
@@ -4441,7 +4469,10 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch): | |
| ctx_draft_tokens = [ | ||
| 0 | ||
| ] * self.model_engine.max_total_draft_tokens | ||
| req.py_draft_tokens = [] if ctx_draft_tokens is None else ctx_draft_tokens | ||
| req.py_draft_tokens = [] if ctx_draft_tokens is None else list( | ||
| ctx_draft_tokens) | ||
| req.draft_tokens = req.py_draft_tokens | ||
| req.py_draft_pages_allocated = len(req.py_draft_tokens) | ||
| beam_width = req.py_beam_width | ||
| if not self._update_sampler_state_for_disagg_gen_request( | ||
| req, beam_width, first_gen_tokens): | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||
| import json | ||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||
| from abc import ABC, abstractmethod | ||||||||||||||||||||||||||
|
|
@@ -676,16 +677,20 @@ class BlockHashMixin: | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _init_block_hashing(self, | ||||||||||||||||||||||||||
| tokens_per_block: int = 32, | ||||||||||||||||||||||||||
| custom_tokenizer: Optional[str] = None): | ||||||||||||||||||||||||||
| custom_tokenizer: Optional[str] = None, | ||||||||||||||||||||||||||
| use_harmony: Optional[bool] = None) -> None: | ||||||||||||||||||||||||||
| env_tokens_per_block = os.environ.get( | ||||||||||||||||||||||||||
| "TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK") | ||||||||||||||||||||||||||
| if env_tokens_per_block is not None: | ||||||||||||||||||||||||||
| tokens_per_block = int(env_tokens_per_block) | ||||||||||||||||||||||||||
| self._tokens_per_block = tokens_per_block | ||||||||||||||||||||||||||
| self._tokenizers: dict = {} | ||||||||||||||||||||||||||
| self._model_types: dict[str, Optional[str]] = {} | ||||||||||||||||||||||||||
| self._custom_tokenizer = custom_tokenizer | ||||||||||||||||||||||||||
| self._use_harmony = use_harmony | ||||||||||||||||||||||||||
| logger.info(f"BlockHashMixin: tokens_per_block={self._tokens_per_block}" | ||||||||||||||||||||||||||
| f", custom_tokenizer={self._custom_tokenizer}") | ||||||||||||||||||||||||||
| f", custom_tokenizer={self._custom_tokenizer}" | ||||||||||||||||||||||||||
| f", use_harmony={self._use_harmony}") | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _get_tokenizer(self, model: str): | ||||||||||||||||||||||||||
| if model not in self._tokenizers: | ||||||||||||||||||||||||||
|
|
@@ -705,12 +710,69 @@ def _get_tokenizer(self, model: str): | |||||||||||||||||||||||||
| model, trust_remote_code=True).tokenizer | ||||||||||||||||||||||||||
| return self._tokenizers[model] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _get_model_type(self, model: str) -> Optional[str]: | ||||||||||||||||||||||||||
| if model not in self._model_types: | ||||||||||||||||||||||||||
| model_type = None | ||||||||||||||||||||||||||
| normalized_model = model.lower().replace("_", "-") | ||||||||||||||||||||||||||
| if "gpt-oss" in normalized_model or "gptoss" in normalized_model: | ||||||||||||||||||||||||||
| model_type = "gpt_oss" | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| config_path = os.path.join(model, "config.json") | ||||||||||||||||||||||||||
| if os.path.isfile(config_path): | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| with open(config_path, encoding="utf-8") as config_file: | ||||||||||||||||||||||||||
| config = json.load(config_file) | ||||||||||||||||||||||||||
| if isinstance(config, dict): | ||||||||||||||||||||||||||
| raw_model_type = config.get("model_type") | ||||||||||||||||||||||||||
| if isinstance(raw_model_type, str): | ||||||||||||||||||||||||||
| model_type = raw_model_type | ||||||||||||||||||||||||||
| except (OSError, json.JSONDecodeError) as e: | ||||||||||||||||||||||||||
| logger.debug( | ||||||||||||||||||||||||||
| "BlockHashMixin: failed to read model config for " | ||||||||||||||||||||||||||
| f"{model}: {e}") | ||||||||||||||||||||||||||
| self._model_types[model] = model_type | ||||||||||||||||||||||||||
| return self._model_types[model] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _uses_harmony_tokenization(self, | ||||||||||||||||||||||||||
| request: ChatCompletionRequest) -> bool: | ||||||||||||||||||||||||||
| if self._use_harmony is not None: | ||||||||||||||||||||||||||
| return self._use_harmony | ||||||||||||||||||||||||||
| return self._get_model_type(request.model) == "gpt_oss" | ||||||||||||||||||||||||||
|
Comment on lines
+736
to
+740
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick win Honor the server-side Harmony disable gate.
Suggested alignment def _uses_harmony_tokenization(self,
request: ChatCompletionRequest) -> bool:
+ if os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1":
+ return False
if self._use_harmony is not None:
return self._use_harmony
return self._get_model_type(request.model) == "gpt_oss"📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||
| def _tool_dicts( | ||||||||||||||||||||||||||
| request: ChatCompletionRequest | ||||||||||||||||||||||||||
| ) -> Optional[list[dict[str, object]]]: | ||||||||||||||||||||||||||
| if request.tools is None: | ||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||
| return [tool.model_dump() for tool in request.tools] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _tokenize_harmony_chat( | ||||||||||||||||||||||||||
| self, request: ChatCompletionRequest) -> list[list[int]]: | ||||||||||||||||||||||||||
| from tensorrt_llm.serve import harmony_adapter | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| tools = self._tool_dicts(request) if request.tools else None | ||||||||||||||||||||||||||
| result = harmony_adapter.get_harmony_adapter().openai_to_harmony_tokens( | ||||||||||||||||||||||||||
| request.messages, | ||||||||||||||||||||||||||
| tools, | ||||||||||||||||||||||||||
| reasoning_effort=harmony_adapter.maybe_transform_reasoning_effort( | ||||||||||||||||||||||||||
| request.reasoning_effort), | ||||||||||||||||||||||||||
| tool_choice=request.tool_choice, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| return [result] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: | ||||||||||||||||||||||||||
| # Handle ChatCompletionRequest (has messages, not prompt) | ||||||||||||||||||||||||||
| if isinstance(request, ChatCompletionRequest): | ||||||||||||||||||||||||||
| if request.prompt_token_ids is not None: | ||||||||||||||||||||||||||
| return [request.prompt_token_ids] | ||||||||||||||||||||||||||
| if self._uses_harmony_tokenization(request): | ||||||||||||||||||||||||||
| return self._tokenize_harmony_chat(request) | ||||||||||||||||||||||||||
| tokenizer = self._get_tokenizer(request.model) | ||||||||||||||||||||||||||
| # Forward tool schemas and chat-template flags so router hashes use | ||||||||||||||||||||||||||
| # the same rendered prompt as the worker-side tokenizer. | ||||||||||||||||||||||||||
| chat_template_kwargs = dict(request.chat_template_kwargs or {}) | ||||||||||||||||||||||||||
| chat_template_kwargs["tools"] = self._tool_dicts(request) | ||||||||||||||||||||||||||
| result = tokenizer.apply_chat_template( | ||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||
| msg if isinstance(msg, dict) else dict(msg) | ||||||||||||||||||||||||||
|
|
@@ -719,14 +781,13 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: | |||||||||||||||||||||||||
| add_generation_prompt=request.add_generation_prompt, | ||||||||||||||||||||||||||
| tokenize=True, | ||||||||||||||||||||||||||
| return_dict=False, | ||||||||||||||||||||||||||
| **chat_template_kwargs, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| # Some custom tokenizers (e.g. DeepseekV32Tokenizer) return a | ||||||||||||||||||||||||||
| # string from apply_chat_template even with tokenize=True. | ||||||||||||||||||||||||||
| # Encode to token IDs if needed. | ||||||||||||||||||||||||||
| if isinstance(result, str): | ||||||||||||||||||||||||||
| result = tokenizer.encode(result, add_special_tokens=False) | ||||||||||||||||||||||||||
| # Set prompt_token_ids so the worker server skips re-tokenization | ||||||||||||||||||||||||||
| request.prompt_token_ids = result | ||||||||||||||||||||||||||
| return [result] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Handle CompletionRequest (has prompt) | ||||||||||||||||||||||||||
|
|
@@ -742,10 +803,6 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| tokenizer = self._get_tokenizer(request.model) | ||||||||||||||||||||||||||
| token_lists = [tokenizer(prompt)["input_ids"] for prompt in prompts] | ||||||||||||||||||||||||||
| # Replace string prompts with token IDs so the worker server | ||||||||||||||||||||||||||
| # skips re-tokenization | ||||||||||||||||||||||||||
| request.prompt = (token_lists | ||||||||||||||||||||||||||
| if len(token_lists) > 1 else token_lists[0]) | ||||||||||||||||||||||||||
| return token_lists | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def _compute_block_hashes(self, | ||||||||||||||||||||||||||
|
|
@@ -799,10 +856,12 @@ def __init__(self, | |||||||||||||||||||||||||
| max_batch_size: int = 64, | ||||||||||||||||||||||||||
| tokens_per_block: int = 32, | ||||||||||||||||||||||||||
| custom_tokenizer: Optional[str] = None, | ||||||||||||||||||||||||||
| use_harmony: Optional[bool] = None, | ||||||||||||||||||||||||||
| **kwargs): | ||||||||||||||||||||||||||
| super().__init__(server_role, servers, metadata_server_cfg, | ||||||||||||||||||||||||||
| metadata_server, **kwargs) | ||||||||||||||||||||||||||
| self._init_block_hashing(tokens_per_block, custom_tokenizer) | ||||||||||||||||||||||||||
| self._init_block_hashing(tokens_per_block, custom_tokenizer, | ||||||||||||||||||||||||||
| use_harmony) | ||||||||||||||||||||||||||
| self._init_load_balancing(servers, use_tokens) | ||||||||||||||||||||||||||
| # TODO: use max_num_tokens? per server? | ||||||||||||||||||||||||||
| self._max_batch_size = max_batch_size | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Use the same draft-token resolver before scheduling and transmission-complete setup.
Line 2630 treats
None/empty ctx drafts as zero scheduled draft tokens, but Line 4478 later turns that same no-draft case intomax_total_draft_tokensdummy drafts whenmodel_engine.enable_spec_decodeis true. That can undercount batch capacity before the request prepares more decode tokens. Move this sync after the current spec-decode decision and share the same resolver with_prepare_disagg_gen_transmission_complete.🤖 Prompt for AI Agents