Skip to content

Commit 6866109

Browse files
committed
Implement raw HTTP training
1 parent f74d1dd commit 6866109

File tree

12 files changed

+248
-55
lines changed

12 files changed

+248
-55
lines changed

ajet/backbone/main_vllm.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,21 @@
99
from ajet.task_rollout.native_parallel_worker import VerlRolloutManager
1010
from ajet.utils.launch_utils import set_loguru_default_color
1111
from ajet.schema.logprob import TokenAndProb
12+
from ajet.utils.core_env_vars import get_runtime_env
1213

1314
set_loguru_default_color()
1415

1516

1617
class TokenAndProbVllmDebug(TokenAndProb):
    """TokenAndProb built from a vLLM/OpenAI ChatCompletionTokenLogprob entry.

    Example input:
    ChatCompletionTokenLogprob(token='token_id:73594', bytes=[96, 96, 96],
                               logprob=-1.9073468138230965e-06, top_logprobs=[])
    """

    def __init__(self, t):
        # vLLM encodes the vocabulary id inside the token string as "token_id:<int>".
        token_id = int(t.token.split("token_id:")[-1])
        logprob = t.logprob
        try:
            decoded_string = bytes(t.bytes).decode("utf-8")
        except (TypeError, ValueError):
            # Narrowed from a bare `except Exception`: bytes(t.bytes) raises
            # TypeError when t.bytes is None and ValueError for out-of-range
            # values; .decode() raises UnicodeDecodeError (a ValueError
            # subclass) for invalid UTF-8. Keep a best-effort placeholder
            # rather than failing the whole rollout.
            decoded_string = "<cannot decode>" + str(t.bytes)
        super().__init__(token_id=token_id, logprob=logprob, decoded_string=decoded_string)
2527

2628

2729
class ChatCompletionScheduler:
@@ -87,6 +89,8 @@ def submit_chat_completions(self, messages, sampling_params, request_id, tools=[
8789

8890

8991
def run(config):
92+
from ajet.task_reader import RouterTaskReader
93+
9094
# --------- fast adjustment for debugging ---------
9195
warm_up_process(config)
9296
max_parallel = config.ajet.debug.debug_max_parallel
@@ -106,7 +110,6 @@ def run(config):
106110
tokenizer=async_rollout_manager.tokenizer,
107111
)
108112

109-
from ajet.task_reader import RouterTaskReader
110113

111114
task_reader = RouterTaskReader(
112115
config.ajet.task_reader.type,
@@ -132,6 +135,12 @@ def main(config):
132135
OmegaConf.resolve(config)
133136
print("*" * 20)
134137

138+
runtime_env = get_runtime_env()
139+
os.environ.update(runtime_env["env_vars"])
140+
if config.ajet.enable_experimental_reverse_proxy:
141+
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
142+
start_interchange_server()
143+
135144
def companion_launch():
136145
import torch
137146

ajet/backbone/trainer_verl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,11 @@ def fit(self): # noqa: C901
832832
progress_bar.update(1)
833833
self.global_steps += 1
834834

835+
# when enabled oai request interchange, we need to clear the cache from time to time
836+
if self.config.ajet.enable_experimental_reverse_proxy:
837+
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear
838+
ensure_dat_interchange_server_cache_clear()
839+
835840
if is_last_step:
836841
pprint(f"Final validation metrics: {last_val_metrics}")
837842
progress_bar.close()

ajet/context_tracker/multiagent_tracking.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,33 @@ def __init__(
6464
self.episode_uuid = episode_uuid
6565

6666

67-
def step_prepare(self, messages: List[dict], tools: List = [], timeline_uuid: str = ""):
67+
def preprocess_tools_field(self, tools: List[dict] = [], disable_toolcalls: bool = False):
    """Normalize the OpenAI-style ``tools`` field before issuing a request.

    When tool calls are disabled the tool list is dropped entirely.
    Otherwise each tool's ``parameters`` entry is re-inserted so that it
    becomes the last key of its ``function`` dict (improves compatibility
    with some backends). A ``None`` input passes through unchanged.
    """
    if disable_toolcalls:
        # Tool calling is off: hand back an empty tool list.
        return []
    if tools is None:
        return None
    for tool in tools:
        function_spec = tool["function"]
        # Pop + reassign moves "parameters" to the end of the dict.
        function_spec["parameters"] = function_spec.pop("parameters")
    return tools
76+
77+
78+
def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_toolcalls: bool = False) -> List[ExtendedMessage]:
79+
"""Spawn a timeline from messages.
80+
81+
Args:
82+
messages: List of message dictionaries
83+
tools: List of tool dictionaries
84+
disable_toolcalls: Whether to disable tool calls
85+
86+
Returns:
87+
List of ExtendedMessage objects representing the timeline
88+
"""
6889
timeline = []
90+
6991
consider_roles = ["user", "assistant", "system", "tool"]
70-
disable_toolcalls = self.config.ajet.rollout.force_disable_toolcalls
7192
if disable_toolcalls:
7293
consider_roles.remove("tool")
73-
tools = []
74-
else:
75-
# rerank tool parameters to improve compatibility
76-
for i in range(len(tools)):
77-
tools[i]["function"]["parameters"] = tools[i]["function"].pop("parameters")
7894

7995
for i, msg in enumerate(messages):
8096
if (disable_toolcalls) and (not isinstance(msg["content"], str)):
@@ -132,6 +148,14 @@ def step_prepare(self, messages: List[dict], tools: List = [], timeline_uuid: st
132148
)
133149
]
134150

151+
return timeline
152+
153+
154+
def step_prepare(self, messages: List[dict], tools: List = [], timeline_uuid: str = ""):
155+
disable_toolcalls = self.config.ajet.rollout.force_disable_toolcalls
156+
tools = self.preprocess_tools_field(tools, disable_toolcalls=disable_toolcalls)
157+
timeline = self.step_spawn_timeline(messages, tools, disable_toolcalls)
158+
135159
# check token overflow
136160
converted_message = self.to_role_content(timeline)
137161
timeline = ExtendedMessage.check_and_merge_chained_tool_response(

ajet/schema/logprob.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
# from verl import DataProto
66

77

8-
class TokenAndProb:
9-
def __init__(self, token_id, logprob, decoded_string):
10-
self.token_id = token_id
11-
self.logprob = logprob
12-
self.decoded_string = decoded_string
8+
from pydantic import BaseModel
9+
10+
11+
class TokenAndProb(BaseModel):
    """A single generated token with its log-probability and decoded text."""

    # Vocabulary index of the token.
    token_id: int
    # Log-probability the model assigned to this token.
    logprob: float
    # UTF-8 decoding of the token bytes, or a "<cannot decode>..." placeholder
    # when the raw bytes are not valid UTF-8 (see TokenAndProbVllmDebug).
    decoded_string: str

ajet/task_rollout/async_llm_bridge.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import uuid
66
from typing import Any, Callable, Dict, List, Literal, Type, Union
77

8-
from agentscope._utils._common import _json_loads_with_repair
9-
from agentscope.message import TextBlock, ToolUseBlock
8+
9+
1010
from loguru import logger
1111
from omegaconf import DictConfig
1212
from pydantic import BaseModel
@@ -33,15 +33,15 @@
3333
class AjetStandardLlmBridgeRequest(BaseModel):
    """Standard request payload accepted by the Ajet LLM bridge."""

    # OpenAI-style chat messages (role/content string dicts).
    messages: List[Dict[str, str]]
    # Per-request sampling parameter overrides — presumably merged with the
    # configured defaults; verify against the bridge implementation.
    custom_sampling_params: dict = {}
    # OpenAI-style tool definitions; empty when no tools are offered.
    # NOTE: pydantic copies field defaults per instance, so `= []` is safe here.
    tools: List = []
    # Caller-supplied identifier used to correlate request and response.
    request_id: str = ""
3838

3939
class AjetStandardLlmBridgeResponse(BaseModel):
    """Standard response payload returned by the Ajet LLM bridge."""

    # Chat role of the produced message; always the assistant side.
    role: str = "assistant"
    # Echo of the originating request's request_id — TODO confirm at call site.
    request_id: str = ""
    # Generated assistant text.
    content: str = ""
    # OpenAI-style tool-call dicts, when the model invoked tools.
    tool_calls: List[Dict] = []
    # Per-token log-probability records for the generated content.
    tokens: List[TokenAndProb] = []
4545

4646

4747
# -------------------------------------------------------------------------------------

ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import asyncio
23
from typing import TYPE_CHECKING, Any, List, Callable, Literal, Type, Union
34
from loguru import logger
@@ -45,10 +46,19 @@ def __init__(
4546
episode_uuid: str,
4647
**kwargs,
4748
):
48-
self.base_url = "http://localhost:27788/v1"
49-
self.api_key = generate_auth_token(
49+
port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
50+
assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set"
51+
base_url = f"http://localhost:{port}/v1"
52+
api_key = generate_auth_token(
5053
agent_name=agent_name,
5154
target_tag=target_tag,
5255
episode_uuid=episode_uuid,
5356
)
54-
self.model = "reserved_field"
57+
model = "reserved_field"
58+
59+
# Properly initialize the Pydantic BaseModel
60+
super().__init__(
61+
base_url=base_url,
62+
api_key=api_key,
63+
model=model,
64+
)

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ async def _service_loop(self):
108108
This design is for efficiency
109109
"""
110110

111-
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import TypeCompletionRequest
111+
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest
112112

113113
port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
114114
assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set"
@@ -118,13 +118,13 @@ async def _service_loop(self):
118118
try:
119119
# Send initialization parameters
120120
# Sending as a list [agent_name, target_tag, episode_uuid] to match "input (a,b,c)" structure
121-
await websocket.send(f"episode_uuid:{self.episode_uuid}")
121+
await websocket.send(pickle.dumps(f"episode_uuid:{self.episode_uuid}"))
122122

123123
while not self.should_terminate:
124124

125125
try:
126126
# wait message from ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py
127-
parsed_msg: TypeCompletionRequest = pickle.loads(
127+
parsed_msg: InterchangeCompletionRequest = pickle.loads(
128128
await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25)
129129
)
130130

0 commit comments

Comments
 (0)