2020import copy
2121import io
2222import json
23+ import logging
2324import os
2425import sys
2526import time
@@ -56,6 +57,9 @@ class RequestFuncInput:
5657 response_format : Optional [dict ] = None
5758 random_flag : bool = False
5859 json_data : Optional [dict ] = None
60+ prompt_token_ids : Optional [list ] = None
61+ tokenizer_model : str = None
62+ tokenizer_path : str = None
5963
6064
6165@dataclass
@@ -81,6 +85,7 @@ class RequestFuncOutput:
8185 error : str = ""
8286 metrics : dict = field (default_factory = dict )
8387 tool_calls : list = field (default_factory = list )
88+ output_ids : list = field (default_factory = list )
8489
8590
8691@dataclass
@@ -178,6 +183,49 @@ def metrics_summary(metrics, token_timestamps):
178183 return summary
179184
180185
186+ def load_tokenizer (model , actor_tokenizer_path ):
187+ """加载tokenizer"""
188+ from ernie_tokenizer import Ernie5Tokenizer , ErnieBotTokenizer
189+ from paddleformers .transformers import AutoTokenizer
190+
191+ from fastdeploy .input .ernie4_5_tokenizer import Ernie4_5Tokenizer
192+
193+ vocab_file_names = ["tokenizer.model" , "spm.model" , "ernie_token_100k.model" ]
194+
195+ try :
196+ if model == "eb" :
197+ for i in range (len (vocab_file_names )):
198+ if os .path .exists (os .path .join (actor_tokenizer_path , vocab_file_names [i ])):
199+ ErnieBotTokenizer .resource_files_names ["vocab_file" ] = vocab_file_names [i ]
200+ break
201+ tokenizer = ErnieBotTokenizer .from_pretrained (actor_tokenizer_path )
202+ elif model == "eb_mm" :
203+ for vocab_file in vocab_file_names :
204+ full_path = os .path .join (actor_tokenizer_path , vocab_file )
205+ if os .path .exists (full_path ):
206+ Ernie4_5Tokenizer .resource_files_names ["vocab_file" ] = vocab_file
207+ # for i in range(len(vocab_file_names)):
208+ # if os.path.exists(os.path.join(actor_tokenizer_path, vocab_file_names[i])):
209+ # Ernie45Tokenizer.resource_files_names["vocab_file"] = vocab_file_names[i]
210+ # break
211+ tokenizer = Ernie4_5Tokenizer .from_pretrained (actor_tokenizer_path )
212+ # tokenizer.ignored_index = -100
213+ elif model == "eb5" :
214+ for i in range (len (vocab_file_names )):
215+ if os .path .exists (os .path .join (actor_tokenizer_path , vocab_file_names [i ])):
216+ Ernie5Tokenizer .resource_files_names ["vocab_file" ] = vocab_file_names [i ]
217+ break
218+ tokenizer = Ernie5Tokenizer .from_pretrained (actor_tokenizer_path )
219+ else :
220+ print ("tokenizer: AUTO" )
221+ tokenizer = AutoTokenizer .from_pretrained (actor_tokenizer_path , padding_side = "left" , use_fast = True )
222+ except Exception as e :
223+ tokenizer = None
224+ logging .warning (f"Load tokenizer error: { e } " )
225+
226+ return tokenizer
227+
228+
181229async def async_request_eb_openai_chat_completions (
182230 request_func_input : RequestFuncInput ,
183231 pbar : Optional [tqdm ] = None ,
@@ -221,6 +269,14 @@ async def async_request_eb_openai_chat_completions(
221269 if request_func_input .response_format :
222270 payload ["response_format" ] = request_func_input .response_format
223271
272+ # 支持传入prompt_token_ids
273+ if request_func_input .prompt_token_ids :
274+ # 不走messages
275+ payload ["messages" ] = [{"role" : "user" , "content" : [{"type" : "text" , "text" : "" }]}]
276+ payload ["prompt_token_ids" ] = request_func_input .prompt_token_ids
277+ payload ["return_token_ids" ] = True
278+ # print("use_token_ids:", payload)
279+
224280 # 超参由yaml传入
225281 payload .update (request_func_input .hyper_parameters )
226282
@@ -298,6 +354,7 @@ async def async_request_eb_openai_chat_completions(
298354 content = choices [0 ]["delta" ].get ("content" )
299355 reason_content = choices [0 ]["delta" ].get ("reasoning_content" )
300356 tool_calls = choices [0 ]["delta" ].get ("tool_calls" )
357+ completion_token_ids = choices [0 ]["delta" ].get ("completion_token_ids" , [])
301358 if tool_calls :
302359 for tc in tool_calls :
303360 idx = tc .get ("index" , 0 )
@@ -343,6 +400,8 @@ async def async_request_eb_openai_chat_completions(
343400
344401 output .generated_text += content or ""
345402 output .reasoning_content += reason_content or ""
403+ if completion_token_ids :
404+ output .output_ids .extend (completion_token_ids )
346405 # print(f"####content:{data}")
347406 output .arrival_time .append (choices [0 ].get ("arrival_time" , timestamp ))
348407 elif usage := data .get ("usage" , {}):
@@ -487,6 +546,27 @@ async def async_request_eb_openai_chat_completions_multi_turn(
487546 print ("START" , request_func_input .no , "user对话轮数:" , user_count , flush = True )
488547 history = []
489548 prompt_no = 0
549+ max_prompt_len = (
550+ hyper .get ("max_prompt_len" ) if hyper .get ("max_prompt_len" ) is not None else json_data .get ("max_prompt_len" )
551+ )
552+ print ("max_prompt_len:" , max_prompt_len )
553+ input_ids_all = []
554+ # FD每轮 completion_token_ids
555+ output_ids = []
556+ use_token_ids = bool (request_func_input .tokenizer_model and request_func_input .tokenizer_path )
557+ tokenizer = None
558+
559+ if use_token_ids :
560+ print ("token ids 拼接模式" )
561+ enable_tools = False
562+ print ("tokenizer_model:" , request_func_input .tokenizer_model )
563+ print ("tokenizer_path:" , request_func_input .tokenizer_path )
564+ tokenizer = load_tokenizer (
565+ request_func_input .tokenizer_model ,
566+ request_func_input .tokenizer_path ,
567+ )
568+ else :
569+ print ("messages 明文拼接模式" )
490570
491571 # 只创建一次 session
492572 session_start = time .perf_counter ()
@@ -508,6 +588,44 @@ async def async_request_eb_openai_chat_completions_multi_turn(
508588 round_input = copy .deepcopy (request_func_input )
509589 round_input .history_QA = history
510590 round_input .no = f"{ round_input .no } _{ prompt_no } "
591+ if use_token_ids :
592+ if len (input_ids_all ) == 0 :
593+ # 拼接token_ids模式,首轮token_ids
594+ spliced_text = tokenizer .apply_chat_template (
595+ history ,
596+ tokenize = False ,
597+ split_special_tokens = False ,
598+ add_special_tokens = False ,
599+ )
600+ # 转换为token ids
601+ tokens = tokenizer .tokenize (spliced_text )
602+ prompt_token_ids = tokenizer .convert_tokens_to_ids (tokens )
603+ input_ids_all .extend (prompt_token_ids )
604+ round_input .prompt_token_ids = input_ids_all
605+ else :
606+ prompt_length = len (input_ids_all ) + len (output_ids )
607+ if max_prompt_len and prompt_length >= max_prompt_len :
608+ # 超长截断
609+ print (
610+ f"[SESSION STOP] { round_input .no } reach max_prompt_len={ max_prompt_len } , stop session"
611+ )
612+ break
613+ # 拼接token_ids模式,后续轮
614+ input_ids_all .extend (output_ids )
615+ user_prompt = message ["content" ]
616+ # 拼接user_prompt
617+ if round_input .tokenizer_model == "eb5" :
618+ # EB5模型
619+ user_prompt = (
620+ f"\n \n <|im_start|>user\n { user_prompt } <|im_end|>\n \n <|im_start|>assistant\n <think>\n "
621+ )
622+ else :
623+ # 0.3B模型,2 </s>,拼接时会被替换成100272 <|end_of_sentence|>
624+ input_ids_all [- 1 ] = 100272
625+ user_prompt = f"User: { user_prompt } \n Assistant: "
626+ prompt_token_ids = tokenizer .convert_tokens_to_ids (tokenizer .tokenize (user_prompt ))
627+ input_ids_all .extend (prompt_token_ids )
628+ round_input .prompt_token_ids = input_ids_all
511629 # 复用 session
512630 s0 = time .perf_counter ()
513631 output = await async_request_eb_openai_chat_completions (
@@ -536,6 +654,14 @@ async def async_request_eb_openai_chat_completions_multi_turn(
536654 input_tokens += output .prompt_tokens
537655 output_tokens += output .output_tokens
538656
657+ # 更新output_ids
658+ output_ids = output .output_ids
659+
660+ if max_prompt_len and input_tokens >= max_prompt_len :
661+ # 后验超长截断
662+ print (f"[SESSION STOP] { round_input .no } reach max_prompt_len={ max_prompt_len } , stop session" )
663+ break
664+
539665 if enable_tools :
540666 # 循环调用工具
541667 max_loop = json_data .get ("max_loop" , 10 )
@@ -643,7 +769,9 @@ async def async_request_eb_openai_chat_completions_multi_turn(
643769 output_tokens += output .output_tokens
644770 # 若session输入长度超过max_prompt_len,则停止session
645771 if max_prompt_len and input_tokens >= max_prompt_len :
646- print (f"[SESSION STOP] { prompt_no } reach max_prompt_len={ max_prompt_len } , stop session" )
772+ print (
773+ f"[SESSION STOP] { round_input .no } reach max_prompt_len={ max_prompt_len } , stop session"
774+ )
647775 session_end = time .perf_counter ()
648776 metrics = SessionMetrics (
649777 session_no = request_func_input .no ,
0 commit comments