@@ -63,7 +63,8 @@ def __init__(
6363 self .llm_mode = llm_mode
6464 self .max_llm_retries = max_llm_retries
6565
66- def get_llm_inference_fn (self , sampling_params : dict = {}) -> Callable : # noqa: C901
66+
67+ def get_llm_inference_fn_sync (self , sampling_params : dict = {}) -> Callable : # noqa: C901
6768
6869 def llm_chat_verl (
6970 messages : List [Dict [str , str ]],
@@ -266,6 +267,206 @@ async def main():
266267
267268
268269
def get_llm_inference_fn_async(self, sampling_params: dict = {}) -> Callable:  # noqa: C901
    """Build and return an async chat-completion callable for the configured backend.

    The returned coroutine function has the signature
    ``(messages, custom_sampling_params={}, tools=[], request_id="") -> dict``
    and produces an assistant-message dict with ``role``, ``request_id``,
    ``content``, ``tool_calls`` and per-token ``tokens`` entries.

    Args:
        sampling_params: Factory-level sampling parameters merged into every
            call; per-call ``custom_sampling_params`` take precedence. The
            default dict is never mutated, so the shared default is safe.

    Returns:
        One of ``llm_chat_remote`` / ``llm_chat_trinity`` / ``llm_chat_verl``,
        selected by ``self.llm_mode`` (anything other than "remote"/"trinity"
        falls back to the verl backend).
    """

    async def llm_chat_verl(
        messages: List[Dict[str, str]],
        custom_sampling_params: dict = {},
        tools=[],
        request_id: str = "",
    ) -> dict:
        # A fresh id is always generated; the incoming request_id argument is
        # kept only for signature parity with the other backends.
        request_id = uuid.uuid4().hex

        # Merge order matters: per-call params override factory-level params.
        updated_sampling_params = {}
        if sampling_params:
            updated_sampling_params.update(sampling_params)
        if custom_sampling_params:
            updated_sampling_params.update(custom_sampling_params)

        # Deep-copy so template expansion cannot mutate the caller's messages.
        input_messages = copy.deepcopy(messages)
        prompt_text = ajet_apply_chat_template(
            tokenizer=self.tokenizer,
            conversation=input_messages,
            tools=tools,
            add_generation_prompt=True,
            tokenize=False,
        )
        prompt_ids = self.tokenizer(prompt_text)["input_ids"]

        if self.config.ajet.execute_test:
            _test_if_test_mode("prompt_text", prompt_text, self.config)

        final_res = await self.async_rollout_manager.generate(
            request_id=request_id,
            prompt_ids=prompt_ids,
            sampling_params=updated_sampling_params,
        )

        if self.config.ajet.rollout.name == "vllm":
            final_res: VerlVllmRequestOutput
            token_array = final_res.outputs[0].token_ids
            logprob_array = final_res.outputs[0].logprobs
        elif self.config.ajet.rollout.name == "sglang":
            token_array = final_res
            # BUGFIX: logprob_array was previously left undefined on this
            # branch, raising NameError in the token comprehension below.
            # sglang returns raw token ids only, so no per-token logprobs are
            # emitted (zip stops at the empty list). TODO(review): confirm
            # the desired logprob handling for sglang.
            logprob_array = []
        else:
            # BUGFIX: an unknown backend previously surfaced as a NameError.
            raise ValueError(f"unsupported rollout backend: {self.config.ajet.rollout.name}")

        decoded_text = self.tokenizer.decode(token_array)  # type: ignore
        if self.config.ajet.execute_test:
            decoded_text = _mock_if_test_mode("mock_decoded_text", decoded_text, self.config)

        # Strip the trailing end-of-turn marker if present.
        if decoded_text.endswith("<|im_end|>"):
            decoded_text = decoded_text[: -len("<|im_end|>")]

        # Parse Hermes-2-Pro style tool calls unless explicitly disabled.
        tool_calls = None
        if (
            ("<tool_call>" in decoded_text)
            and ("</tool_call>" in decoded_text)
            and (not self.config.ajet.rollout.force_disable_toolcalls)
        ):
            tool_parser = Hermes2ProToolParser(self.tokenizer)
            parsed_tool_calls = tool_parser.extract_tool_calls(decoded_text, None)  # type: ignore
            parsed_tool_calls = parsed_tool_calls.model_dump()
            if self.config.ajet.execute_test:
                _test_if_test_mode(
                    "parsed_tool_calls", parsed_tool_calls["tool_calls"], self.config
                )
            if parsed_tool_calls["tools_called"]:
                tool_calls = parsed_tool_calls["tool_calls"]
                # Reject the whole batch if any call's arguments do not decode
                # to a JSON object.
                is_bad_toolcall = False
                for call in tool_calls:
                    if "function" in call and "arguments" in call["function"]:
                        if not isinstance(json.loads(call["function"]["arguments"]), dict):
                            is_bad_toolcall = True
                if is_bad_toolcall:
                    # Keep decoded_text as the raw model output.
                    tool_calls = None
                else:
                    decoded_text = parsed_tool_calls["content"]
                    if decoded_text is None:
                        decoded_text = ""

        return {
            "role": "assistant",
            "request_id": request_id,
            "content": decoded_text,
            "tool_calls": tool_calls,
            "tokens": [
                TokenAndProb(
                    token_id=token_id,
                    # Warning: vllm logprob does not participate in training
                    # (not reliable enough); for logging only.
                    logprob=logprob[token_id].logprob,
                    decoded_string=logprob[token_id].decoded_token,
                )
                for token_id, logprob in zip(token_array, logprob_array)  # type: ignore
            ],
        }

    async def llm_chat_remote(
        messages: List[Dict[str, str]],
        custom_sampling_params: dict = {},
        tools=[],
        request_id: str = "",
    ) -> dict:
        # Merge order matters: per-call params override factory-level params;
        # logprob options are forced on for token-level logging downstream.
        updated_sampling_params = {}
        if sampling_params:
            updated_sampling_params.update(sampling_params)
        if custom_sampling_params:
            updated_sampling_params.update(custom_sampling_params)
        updated_sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True})
        input_messages = copy.deepcopy(messages)

        output_message = None
        last_error = None
        for attempt in range(self.max_llm_retries):
            try:
                # this function is defined in `ajet/backbone/main_vllm.py`
                output_message = await self.async_rollout_manager.submit_chat_completions_async(
                    messages=input_messages,
                    sampling_params=updated_sampling_params,
                    tools=tools,
                    request_id=request_id,
                )
                break
            except Exception as e:
                last_error = e
                logger.bind(exception=True).exception(f"rollout_server.{attempt} error: {e.args}")
                # BUGFIX: time.sleep blocked the event loop inside a coroutine;
                # use the async sleep for linear backoff instead.
                await asyncio.sleep(attempt + 1)
        if output_message is None:
            # BUGFIX: previously fell through to an UnboundLocalError after
            # exhausting retries; raise a meaningful error chained to the cause.
            raise RuntimeError(
                f"LLM inference failed after {self.max_llm_retries} retries"
            ) from last_error
        return output_message[-1]  # type: ignore

    async def llm_chat_trinity(
        messages: List[Dict[str, str]],
        custom_sampling_params: dict = {},
        tools=[],
        request_id: str = "",  # unused; kept for signature parity with the other backends
    ) -> dict:
        # Merge order matters: per-call params override factory-level params.
        updated_sampling_params = {}
        if sampling_params:
            updated_sampling_params.update(sampling_params)
        if custom_sampling_params:
            updated_sampling_params.update(custom_sampling_params)
        # BUGFIX: a plain .pop raised KeyError when "min_tokens" was absent;
        # the intent is "remove if present" (unsupported by this endpoint).
        updated_sampling_params.pop("min_tokens", None)

        # Single create() call; "tools" is only forwarded when non-empty.
        create_kwargs = dict(
            model=self.async_rollout_manager.model_path,
            messages=messages,
            logprobs=True,
            top_logprobs=0,
            **updated_sampling_params,
        )
        if tools:
            create_kwargs["tools"] = tools
        response = await self.async_rollout_manager.chat.completions.create(**create_kwargs)

        prompt_token_ids = response.model_extra["prompt_token_ids"]
        prompt_text = self.tokenizer.decode(prompt_token_ids)
        content = response.choices[0].message.content
        message = response.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)

        if content is None:
            content = ""

        # A raw "<tool_call>" marker without parsed tool_calls means the server
        # failed to extract the call; log it for diagnosis.
        if ("<tool_call>" in content) and (not message.get("tool_calls", None)):
            logger.warning(f"Bad toolcall discovered: {content}")

        return {
            "role": "assistant",
            "request_id": response.id,
            "content": content,
            "prompt_text": prompt_text,
            "prompt_token_ids": prompt_token_ids,
            "tool_calls": message.get("tool_calls", []),
            "tokens": [
                TokenAndProb(
                    token_id=token,
                    # Warning: vllm logprob does not participate in training;
                    # for logging only.
                    logprob=tokenlogprob.logprob,
                    decoded_string=tokenlogprob.token,
                )
                for tokenlogprob, token in zip(
                    response.choices[0].logprobs.content,
                    response.choices[0].token_ids,
                )
            ],
        }

    if self.llm_mode == "remote":
        return llm_chat_remote
    elif self.llm_mode == "trinity":
        return llm_chat_trinity
    else:
        return llm_chat_verl
269470
270471# ----------------------------------------------------------------------------------------------
271472# ------------------------ call async llm with context tracker (OpenAI) ------------------------
@@ -334,12 +535,15 @@ async def run_infer(
334535 # otherwise, for abnormal output, can still proceed, but we do not track output anymore
335536
336537 # run llm inference ✨
337- llm_output = await asyncio .wait_for (
338- asyncio .to_thread (
339- self .llm_inference_fn , converted_message , custom_sampling_params , tools
340- ),
341- timeout = 1800 ,
342- )
538+ # if sync:
539+ # llm_output = await asyncio.wait_for(
540+ # asyncio.to_thread(
541+ # self.llm_inference_fn, converted_message, custom_sampling_params, tools
542+ # ),
543+ # timeout=1800,
544+ # )
545+ llm_output = await asyncio .wait_for (self .llm_inference_fn (converted_message , custom_sampling_params , tools ), timeout = 1800 )
546+
343547
344548 # begin context tracking
345549 self .context_tracker .step_track (llm_output , context_safe , converted_message , tools , timeline_uuid = timeline_uuid )
0 commit comments