1313# limitations under the License.
1414
1515import asyncio
16+ import json
1617import os
1718import time
1819from abc import ABC , abstractmethod
@@ -676,16 +677,20 @@ class BlockHashMixin:
676677
677678 def _init_block_hashing (self ,
678679 tokens_per_block : int = 32 ,
679- custom_tokenizer : Optional [str ] = None ):
680+ custom_tokenizer : Optional [str ] = None ,
681+ use_harmony : Optional [bool ] = None ) -> None :
680682 env_tokens_per_block = os .environ .get (
681683 "TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK" )
682684 if env_tokens_per_block is not None :
683685 tokens_per_block = int (env_tokens_per_block )
684686 self ._tokens_per_block = tokens_per_block
685687 self ._tokenizers : dict = {}
688+ self ._model_types : dict [str , Optional [str ]] = {}
686689 self ._custom_tokenizer = custom_tokenizer
690+ self ._use_harmony = use_harmony
687691 logger .info (f"BlockHashMixin: tokens_per_block={ self ._tokens_per_block } "
688- f", custom_tokenizer={ self ._custom_tokenizer } " )
692+ f", custom_tokenizer={ self ._custom_tokenizer } "
693+ f", use_harmony={ self ._use_harmony } " )
689694
690695 def _get_tokenizer (self , model : str ):
691696 if model not in self ._tokenizers :
@@ -705,12 +710,69 @@ def _get_tokenizer(self, model: str):
705710 model , trust_remote_code = True ).tokenizer
706711 return self ._tokenizers [model ]
707712
713+ def _get_model_type (self , model : str ) -> Optional [str ]:
714+ if model not in self ._model_types :
715+ model_type = None
716+ normalized_model = model .lower ().replace ("_" , "-" )
717+ if "gpt-oss" in normalized_model or "gptoss" in normalized_model :
718+ model_type = "gpt_oss"
719+ else :
720+ config_path = os .path .join (model , "config.json" )
721+ if os .path .isfile (config_path ):
722+ try :
723+ with open (config_path , encoding = "utf-8" ) as config_file :
724+ config = json .load (config_file )
725+ if isinstance (config , dict ):
726+ raw_model_type = config .get ("model_type" )
727+ if isinstance (raw_model_type , str ):
728+ model_type = raw_model_type
729+ except (OSError , json .JSONDecodeError ) as e :
730+ logger .debug (
731+ "BlockHashMixin: failed to read model config for "
732+ f"{ model } : { e } " )
733+ self ._model_types [model ] = model_type
734+ return self ._model_types [model ]
735+
736+ def _uses_harmony_tokenization (self ,
737+ request : ChatCompletionRequest ) -> bool :
738+ if self ._use_harmony is not None :
739+ return self ._use_harmony
740+ return self ._get_model_type (request .model ) == "gpt_oss"
741+
742+ @staticmethod
743+ def _tool_dicts (
744+ request : ChatCompletionRequest
745+ ) -> Optional [list [dict [str , object ]]]:
746+ if request .tools is None :
747+ return None
748+ return [tool .model_dump () for tool in request .tools ]
749+
750+ def _tokenize_harmony_chat (
751+ self , request : ChatCompletionRequest ) -> list [list [int ]]:
752+ from tensorrt_llm .serve import harmony_adapter
753+
754+ tools = self ._tool_dicts (request ) if request .tools else None
755+ result = harmony_adapter .get_harmony_adapter ().openai_to_harmony_tokens (
756+ request .messages ,
757+ tools ,
758+ reasoning_effort = harmony_adapter .maybe_transform_reasoning_effort (
759+ request .reasoning_effort ),
760+ tool_choice = request .tool_choice ,
761+ )
762+ return [result ]
763+
708764 def _tokenize (self , request : OpenAIRequest ) -> list [list [int ]]:
709765 # Handle ChatCompletionRequest (has messages, not prompt)
710766 if isinstance (request , ChatCompletionRequest ):
711767 if request .prompt_token_ids is not None :
712768 return [request .prompt_token_ids ]
769+ if self ._uses_harmony_tokenization (request ):
770+ return self ._tokenize_harmony_chat (request )
713771 tokenizer = self ._get_tokenizer (request .model )
772+ # Forward tool schemas and chat-template flags so router hashes use
773+ # the same rendered prompt as the worker-side tokenizer.
774+ chat_template_kwargs = dict (request .chat_template_kwargs or {})
775+ chat_template_kwargs ["tools" ] = self ._tool_dicts (request )
714776 result = tokenizer .apply_chat_template (
715777 [
716778 msg if isinstance (msg , dict ) else dict (msg )
@@ -719,14 +781,13 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
719781 add_generation_prompt = request .add_generation_prompt ,
720782 tokenize = True ,
721783 return_dict = False ,
784+ ** chat_template_kwargs ,
722785 )
723786 # Some custom tokenizers (e.g. DeepseekV32Tokenizer) return a
724787 # string from apply_chat_template even with tokenize=True.
725788 # Encode to token IDs if needed.
726789 if isinstance (result , str ):
727790 result = tokenizer .encode (result , add_special_tokens = False )
728- # Set prompt_token_ids so the worker server skips re-tokenization
729- request .prompt_token_ids = result
730791 return [result ]
731792
732793 # Handle CompletionRequest (has prompt)
@@ -742,10 +803,6 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
742803
743804 tokenizer = self ._get_tokenizer (request .model )
744805 token_lists = [tokenizer (prompt )["input_ids" ] for prompt in prompts ]
745- # Replace string prompts with token IDs so the worker server
746- # skips re-tokenization
747- request .prompt = (token_lists
748- if len (token_lists ) > 1 else token_lists [0 ])
749806 return token_lists
750807
751808 def _compute_block_hashes (self ,
@@ -799,10 +856,12 @@ def __init__(self,
799856 max_batch_size : int = 64 ,
800857 tokens_per_block : int = 32 ,
801858 custom_tokenizer : Optional [str ] = None ,
859+ use_harmony : Optional [bool ] = None ,
802860 ** kwargs ):
803861 super ().__init__ (server_role , servers , metadata_server_cfg ,
804862 metadata_server , ** kwargs )
805- self ._init_block_hashing (tokens_per_block , custom_tokenizer )
863+ self ._init_block_hashing (tokens_per_block , custom_tokenizer ,
864+ use_harmony )
806865 self ._init_load_balancing (servers , use_tokens )
807866 # TODO: use max_num_tokens? per server?
808867 self ._max_batch_size = max_batch_size
0 commit comments