Skip to content

Commit 666e8d3

Browse files
committed
fix interchange server boot
1 parent 75a4321 commit 666e8d3

File tree

4 files changed

+90
-6
lines changed

4 files changed

+90
-6
lines changed

ajet/default_config/ajet_default.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ ajet:
192192
fix_retokenization_drift: True
193193

194194
# log tool format check results
195-
log_tool_format_check: True
195+
log_tool_format_check: False
196196

197197
# log tool format check error details
198198
log_tool_format_error_detail: False
@@ -281,7 +281,7 @@ ajet:
281281

282282

283283
# the experimental reverse proxy feature that allows `tuner.as_oai_baseurl_apikey` feature
284-
enable_experimental_interchange_server: False
284+
enable_experimental_interchange_server: True
285285
interchange_server:
286286
interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node)
287287
interchange_server_port: 'auto'

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio
102102
return result_object
103103

104104

105+
@app.get("/health")
async def health():
    """Liveness probe; start_interchange_server polls this URL to confirm the server booted."""
    return {"status": "ok"}
108+
109+
105110
@app.post("/v1/chat/completions")
106111
async def chat_completions(request: Request, authorization: str = Header(None)):
107112
"""
@@ -167,6 +172,7 @@ def __init__(self, experiment_dir: str, port: int, num_fastapi_process: int = 2,
167172
self.max_fastapi_threads = max_fastapi_threads
168173

169174
def run(self):
175+
logger.info(f"Starting Interchange Server on port {self.port} with {self.num_fastapi_process} processes and {self.max_fastapi_threads} threads per process.")
170176
app = get_app(self.max_fastapi_threads)
171177
async def serve_with_monitor():
172178
# Start the server
@@ -215,15 +221,16 @@ def start_interchange_server(config) -> int:
215221
logger.error(f"Interchange server subprocess failed to start. Return code: {interchange_server.exitcode}")
216222
raise RuntimeError("Interchange server subprocess failed to start.")
217223
if time.time() - start_time > 30:
218-
logger.error("Interchange server subprocess failed to start within 30 seconds.")
219-
raise RuntimeError("Interchange server subprocess failed to start within 30 seconds.")
224+
msg = f"Interchange server subprocess failed to start within {time.time() - start_time} seconds."
225+
logger.error(msg)
226+
raise RuntimeError(msg)
220227
try:
221228
if httpx.get(health_url, timeout=0.5).status_code == 200:
222229
break
223230
except Exception:
224231
# keep waiting
225232
pass
226-
time.sleep(0.5)
233+
time.sleep(1)
227234

228235
# register a termination handler
229236
if DEBUG: logger.info(f"Interchange server subprocess started on port {port} (pid: {interchange_server.pid})")

docs/en/installation.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ AgentJet supports multiple backbones. Currently we have `verl` and `trinity` (re
8383
```
8484

8585
!!! warning "flash-attn Installation"
86-
`flash-attn` must be installed after other dependencies. To build faster, export `MAX_JOBS=${N_CPU}`, or ensure a healthy connection to GitHub to install pre-compiled wheels.
86+
- `flash-attn` must be installed **after** other dependencies.
87+
- Ensure a healthy connection to GitHub to install pre-compiled wheels.
88+
- If your machine spends a long time installing flash-attn, ensure a healthy connection to GitHub.
89+
- To build faster, export `MAX_JOBS=${N_CPU}`.
8790

8891

8992
=== "Trinity"
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import re
2+
from loguru import logger
3+
from agentscope.message import Msg
4+
from agentscope.agent import ReActAgent
5+
from agentscope.formatter import OpenAIChatFormatter
6+
from agentscope.model import OpenAIChatModel
7+
from agentscope.memory import InMemoryMemory
8+
from agentscope.tool import Toolkit, execute_python_code
9+
from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask
10+
11+
12+
def extract_final_answer(result) -> str:
    """Pull the final answer string out of an agent reply object.

    Lookup order: `result.metadata["result"]`, then `result.content["result"]`,
    then `str(result.content)`, finally `str(result)` as the fallback.
    Any unexpected failure is logged and the raw object is stringified.
    """
    try:
        if hasattr(result, "metadata") and isinstance(result.metadata, dict):
            metadata = result.metadata
            if "result" in metadata:
                return metadata["result"]
        if hasattr(result, "content"):
            content = result.content
            if isinstance(content, dict) and "result" in content:
                return content["result"]
            return str(content)
    except Exception as e:
        logger.warning(f"Extract final answer error: {e}. Raw: {result}")
    return str(result)
29+
30+
31+
# System prompt for the math ReAct agent. This is a plain string (never passed
# through str.format), so braces must not be doubled: the model should see
# `\boxed{}`, not the `{{}}` escape form left over from a format template.
system_prompt = """
You are an agent specialized in solving math problems with tools.
Please solve the math problem given to you.
You can write and execute Python code to perform calculation or verify your answer.
You should return your final answer within \\boxed{}.
"""
37+
38+
39+
class MathToolWorkflow(Workflow):  # ✨✨ inherit `Workflow` class
    """ReAct math-agent workflow.

    Runs one rollout: the agent solves `task.main_query` with a Python
    code-execution tool, then the rollout is scored 1.0/0.0 by comparing the
    agent's `\\boxed{...}` answer against the reference in
    `task.metadata["answer"]` (GSM8K-style: final answer after "####").
    """
    name: str = "math_agent_workflow"

    async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput:
        # run agentscope
        query = workflow_task.task.main_query
        self.toolkit = Toolkit()
        self.toolkit.register_tool_function(execute_python_code)

        # route the agent's OpenAI-compatible traffic through the tuner
        url_and_apikey = tuner.as_oai_baseurl_apikey()
        base_url = url_and_apikey.base_url
        api_key = url_and_apikey.api_key  # the api key carries routing information, do not discard it
        model = OpenAIChatModel(
            model_name="whatever",  # model name is ignored by the interchange proxy
            client_args={"base_url": base_url},
            api_key=api_key,
            stream=False,
        )
        self.agent = ReActAgent(
            name="math_react_agent", sys_prompt=system_prompt,
            model=model,  # ✨✨ compared with a normal agentscope agent, here is the difference!
            formatter=OpenAIChatFormatter(),
            toolkit=self.toolkit,
            memory=InMemoryMemory(), max_iters=2,
        )
        self.agent.set_console_output_enabled(False)
        msg = Msg("user", query, role="user")
        result = await self.agent.reply(msg)
        final_answer = extract_final_answer(result)

        # compute reward: the reference is the text after "####", whitespace-stripped
        reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip()
        match = re.search(r"\\boxed\{([^}]*)\}", final_answer)
        # fix: strip the boxed content as well — the reference is stripped above,
        # so `\boxed{ 72 }` must still match a reference of `72`.
        # NOTE(review): this regex does not handle nested braces (e.g. \boxed{\frac{1}{2}}).
        is_success = bool(match) and (match.group(1).strip() == reference_answer)
        return WorkflowOutput(reward=(1.0 if is_success else 0.0), metadata={"final_answer": final_answer})

0 commit comments

Comments
 (0)