Skip to content

Commit fc0a680

Browse files
committed
maybe stable
1 parent 855cb2f commit fc0a680

File tree

6 files changed

+345
-47
lines changed

6 files changed

+345
-47
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,5 @@ saved_checkpoints
150150
data
151151
datasets
152152
tutorial2
153-
site
153+
site
154+
dump.rdb

ajet/backbone/main_vllm.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from types import SimpleNamespace
44

55
import hydra
6-
from openai import OpenAI
6+
from openai import AsyncOpenAI, OpenAI
77

88
from ajet.backbone.warm_up import warm_up_process
99
from ajet.task_rollout.native_parallel_worker import VerlRolloutManager
@@ -88,6 +88,53 @@ def submit_chat_completions(self, messages, sampling_params, request_id, tools=[
8888
)
8989
return messages
9090

91+
async def submit_chat_completions_async(self, messages, sampling_params, request_id, tools=[]):
92+
client = AsyncOpenAI(
93+
base_url=self.url,
94+
api_key="token-abc123",
95+
)
96+
sampling_params = dict(
97+
n=1,
98+
max_completion_tokens=self.config.ajet.rollout.max_response_length_in_one_turn,
99+
)
100+
sampling_params["temperature"] = self.config.ajet.rollout.val_kwargs.temperature
101+
sampling_params["top_k"] = self.config.ajet.rollout.val_kwargs.top_k
102+
sampling_params["top_p"] = self.config.ajet.rollout.val_kwargs.top_p
103+
104+
sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True})
105+
106+
if tools:
107+
completion = await client.chat.completions.create(
108+
model=self.config.ajet.model.path,
109+
messages=messages,
110+
tools=tools,
111+
extra_body=sampling_params,
112+
)
113+
else:
114+
completion = await client.chat.completions.create(
115+
model=self.config.ajet.model.path,
116+
messages=messages,
117+
extra_body=sampling_params,
118+
)
119+
120+
message = completion.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)
121+
122+
# sometimes tool use message has no content field
123+
if "content" not in message:
124+
message["content"] = ""
125+
126+
messages.append(
127+
{
128+
"role": message["role"],
129+
"request_id": completion.id,
130+
"content": message["content"],
131+
"tool_calls": message.get("tool_calls", None),
132+
"tokens": [
133+
TokenAndProbVllmDebug(t) for t in completion.choices[0].logprobs.content # type: ignore
134+
],
135+
}
136+
)
137+
return messages
91138

92139
def run(config):
93140
from ajet.task_reader import RouterTaskReader

ajet/task_rollout/async_llm_bridge.py

Lines changed: 211 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def __init__(
6363
self.llm_mode = llm_mode
6464
self.max_llm_retries = max_llm_retries
6565

66-
def get_llm_inference_fn(self, sampling_params: dict = {}) -> Callable: # noqa: C901
66+
67+
def get_llm_inference_fn_sync(self, sampling_params: dict = {}) -> Callable: # noqa: C901
6768

6869
def llm_chat_verl(
6970
messages: List[Dict[str, str]],
@@ -266,6 +267,206 @@ async def main():
266267

267268

268269

270+
def get_llm_inference_fn_async(self, sampling_params: dict = {}) -> Callable: # noqa: C901
271+
272+
async def llm_chat_verl(
273+
messages: List[Dict[str, str]],
274+
custom_sampling_params: dict = {},
275+
tools=[],
276+
request_id: str = "",
277+
) -> dict:
278+
request_id = uuid.uuid4().hex
279+
280+
updated_sampling_params = {}
281+
if sampling_params:
282+
updated_sampling_params.update(sampling_params)
283+
if custom_sampling_params:
284+
updated_sampling_params.update(custom_sampling_params)
285+
286+
input_messages = copy.deepcopy(messages)
287+
prompt_text = ajet_apply_chat_template(
288+
tokenizer=self.tokenizer,
289+
conversation=input_messages,
290+
tools=tools,
291+
add_generation_prompt=True,
292+
tokenize=False,
293+
)
294+
prompt_ids = self.tokenizer(prompt_text)["input_ids"]
295+
296+
if self.config.ajet.execute_test:
297+
_test_if_test_mode("prompt_text", prompt_text, self.config)
298+
299+
final_res = await self.async_rollout_manager.generate(
300+
request_id=request_id,
301+
prompt_ids=prompt_ids,
302+
sampling_params=updated_sampling_params,
303+
)
304+
305+
if self.config.ajet.rollout.name == "vllm":
306+
final_res: VerlVllmRequestOutput
307+
token_array = final_res.outputs[0].token_ids
308+
logprob_array = final_res.outputs[0].logprobs
309+
elif self.config.ajet.rollout.name == "sglang":
310+
token_array = final_res
311+
312+
decoded_text = self.tokenizer.decode(token_array) # type: ignore
313+
if self.config.ajet.execute_test:
314+
decoded_text = _mock_if_test_mode("mock_decoded_text", decoded_text, self.config)
315+
316+
if decoded_text.endswith("<|im_end|>"):
317+
decoded_text = decoded_text[: -len("<|im_end|>")]
318+
319+
# if tool call
320+
tool_calls = None
321+
if (
322+
("<tool_call>" in decoded_text)
323+
and ("</tool_call>" in decoded_text)
324+
and (not self.config.ajet.rollout.force_disable_toolcalls)
325+
):
326+
tool_parser = Hermes2ProToolParser(self.tokenizer)
327+
parsed_tool_calls = tool_parser.extract_tool_calls(decoded_text, None) # type: ignore
328+
parsed_tool_calls = parsed_tool_calls.model_dump()
329+
if self.config.ajet.execute_test:
330+
_test_if_test_mode(
331+
"parsed_tool_calls", parsed_tool_calls["tool_calls"], self.config
332+
)
333+
model_called = parsed_tool_calls["tools_called"]
334+
if model_called:
335+
tool_calls = parsed_tool_calls["tool_calls"]
336+
is_bad_toolcall = False
337+
for i in range(len(tool_calls)):
338+
if "function" in tool_calls[i] and "arguments" in tool_calls[i]["function"]:
339+
expect_dict = json.loads(tool_calls[i]["function"]["arguments"])
340+
if not isinstance(expect_dict, dict):
341+
is_bad_toolcall = True
342+
if is_bad_toolcall:
343+
tool_calls = None
344+
decoded_text = decoded_text
345+
else:
346+
decoded_text = parsed_tool_calls["content"]
347+
if decoded_text is None:
348+
decoded_text = ""
349+
350+
return {
351+
"role": "assistant",
352+
"request_id": request_id,
353+
"content": decoded_text,
354+
"tool_calls": tool_calls,
355+
"tokens": [
356+
TokenAndProb(
357+
token_id=token_id,
358+
logprob=logprob[token_id].logprob, # Warning: vllm logprob does not participant training (not reliable enough), for log only.
359+
decoded_string=logprob[token_id].decoded_token,
360+
)
361+
for token_id, logprob in zip(token_array, logprob_array) # type: ignore
362+
],
363+
}
364+
365+
366+
async def llm_chat_remote(
367+
messages: List[Dict[str, str]],
368+
custom_sampling_params: dict = {},
369+
tools=[],
370+
request_id: str = "",
371+
) -> dict:
372+
updated_sampling_params = {}
373+
if sampling_params:
374+
updated_sampling_params.update(sampling_params)
375+
if custom_sampling_params:
376+
updated_sampling_params.update(custom_sampling_params)
377+
updated_sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True})
378+
input_messages = copy.deepcopy(messages)
379+
for i in range(self.max_llm_retries):
380+
try:
381+
# this function is defined in `ajet/backbone/main_vllm.py`
382+
output_message = await self.async_rollout_manager.submit_chat_completions_async(
383+
messages=input_messages,
384+
sampling_params=updated_sampling_params,
385+
tools=tools,
386+
request_id=request_id,
387+
)
388+
break
389+
except Exception as e:
390+
logger.bind(exception=True).exception(f"rollout_server.{i} error: {e.args}")
391+
time.sleep(i + 1)
392+
return output_message[-1] # type: ignore
393+
394+
395+
async def llm_chat_trinity(
396+
messages: List[Dict[str, str]],
397+
custom_sampling_params: dict = {},
398+
tools=[],
399+
request_id: str = "",
400+
) -> dict:
401+
async def main():
402+
updated_sampling_params = {}
403+
if sampling_params:
404+
updated_sampling_params.update(sampling_params)
405+
if custom_sampling_params:
406+
updated_sampling_params.update(custom_sampling_params)
407+
updated_sampling_params.pop("min_tokens")
408+
409+
if tools:
410+
response = await self.async_rollout_manager.chat.completions.create(
411+
model=self.async_rollout_manager.model_path,
412+
messages=messages,
413+
logprobs=True,
414+
tools=tools,
415+
top_logprobs=0,
416+
**updated_sampling_params,
417+
)
418+
else:
419+
response = await self.async_rollout_manager.chat.completions.create(
420+
model=self.async_rollout_manager.model_path,
421+
messages=messages,
422+
logprobs=True,
423+
top_logprobs=0,
424+
**updated_sampling_params,
425+
)
426+
return response
427+
428+
response = await main()
429+
prompt_text = self.tokenizer.decode(response.model_extra["prompt_token_ids"])
430+
prompt_token_ids = response.model_extra["prompt_token_ids"]
431+
content = response.choices[0].message.content
432+
message = response.choices[0].message.model_dump(exclude_unset=True, exclude_none=True)
433+
434+
if content is None:
435+
content = ""
436+
437+
if ("<tool_call>" in content) and (not message.get("tool_calls", None)):
438+
# logger.bind(exception=True).exception(f"Bad toolcall discovered \n\nprompt_text:\n{prompt_text}\n\nrepsonse:\n{content}")
439+
logger.warning(f"Bad toolcall discovered: {content}")
440+
441+
return {
442+
"role": "assistant",
443+
"request_id": response.id,
444+
"content": content,
445+
"prompt_text": prompt_text,
446+
"prompt_token_ids": prompt_token_ids,
447+
"tool_calls": message.get("tool_calls", []),
448+
"tokens": [
449+
TokenAndProb(
450+
token_id=token,
451+
logprob=tokenlogprob.logprob, # Warning: vllm logprob does not participant training, for log only.
452+
decoded_string=tokenlogprob.token,
453+
)
454+
for tokenlogprob, token in zip(
455+
response.choices[0].logprobs.content,
456+
response.choices[0].token_ids,
457+
)
458+
],
459+
}
460+
461+
if self.llm_mode == "remote":
462+
return llm_chat_remote
463+
if self.llm_mode == "trinity":
464+
return llm_chat_trinity
465+
else:
466+
return llm_chat_verl
467+
468+
469+
269470

270471
# ----------------------------------------------------------------------------------------------
271472
# ------------------------ call async llm with context tracker (OpenAI) ------------------------
@@ -334,12 +535,15 @@ async def run_infer(
334535
# otherwise, for abnormal output, can still proceed, but we do not track output anymore
335536

336537
# run llm inference ✨
337-
llm_output = await asyncio.wait_for(
338-
asyncio.to_thread(
339-
self.llm_inference_fn, converted_message, custom_sampling_params, tools
340-
),
341-
timeout=1800,
342-
)
538+
# if sync:
539+
# llm_output = await asyncio.wait_for(
540+
# asyncio.to_thread(
541+
# self.llm_inference_fn, converted_message, custom_sampling_params, tools
542+
# ),
543+
# timeout=1800,
544+
# )
545+
llm_output = await asyncio.wait_for(self.llm_inference_fn(converted_message, custom_sampling_params, tools), timeout=1800)
546+
343547

344548
# begin context tracking
345549
self.context_tracker.step_track(llm_output, context_safe, converted_message, tools, timeline_uuid=timeline_uuid)

ajet/task_rollout/single_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def rollout_env_worker(
8484
(with validation overrides), and robust retry on transient failures.
8585
"""
8686
sampling_params = get_sample_params(mode, self.config)
87-
llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn(
87+
llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_async(
8888
sampling_params=sampling_params
8989
)
9090

0 commit comments

Comments
 (0)