Skip to content

Commit 0ab6d6b

Browse files
committed
Refactor MultiAgentContextTracker and ExtendedMessage for improved message handling; update oai_model_client to use asyncio for ZMQ communication; adjust math_agent.yaml model path; modify agent_roll_v3.py for batch size and experiment naming.
1 parent da51a13 commit 0ab6d6b

5 files changed

Lines changed: 55 additions & 62 deletions

File tree

ajet/context_tracker/multiagent_tracking.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
166166
else:
167167
author = "env"
168168

169+
# extract content block from openai-competible messages and convert to ExtendedMessage
169170
timeline += [
170171
ExtendedMessage(
171172
author=author,
@@ -235,6 +236,7 @@ def step_track(
235236

236237
tool_calls = self.detect_tool_call_madness(llm_output)
237238

239+
# add llm_output to timeline and save
238240
llm_ext_msg = ExtendedMessage(
239241
author="llm",
240242
role="assistant",

ajet/schema/extended_msg.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101
self.manual_loss_mask_override = []
102102
self.lack_normal_eos = False
103103

104-
self.generate_content_for_compare(tokenizer=None)
104+
self.generate_content_for_compare(content = self.content)
105105

106106
self.eos_token_id = tokenizer.eos_token_id
107107

@@ -173,7 +173,7 @@ def content_for_compare(self):
173173
if self._content_for_compare == "":
174174
if not self.tool_calls:
175175
logger.exception("content_for_compare is not set, or previous llm output is empty!")
176-
self._content_for_compare
176+
# self._content_for_compare
177177
return self._content_for_compare
178178

179179
@property
@@ -185,9 +185,8 @@ def need_training(self):
185185
), f"author {self.author} is not identified"
186186
return self.author in NEED_TRAIN_AUTHORS
187187

188-
def generate_content_for_compare(self, tokenizer):
189-
_content: str = self.content
190-
self._content_for_compare = _content
188+
def generate_content_for_compare(self, content):
189+
self._content_for_compare = content
191190

192191
def get_loss_mask(self, blackout_token_combo):
193192
if self.need_training:
@@ -302,6 +301,7 @@ def merge_tool_group(group, tokenizer):
302301
)
303302
merged_content = merged_content[len("<tool_response>\n") :]
304303
merged_content = merged_content[: -len("</tool_response>\n")]
304+
# create merged tool response block
305305
merged = ExtendedMessage(
306306
author=msg0.author,
307307
role=msg0.role,

ajet/tuner_lib/experimental/oai_model_client.py

Lines changed: 42 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,19 @@
55
import os
66
import time
77
import zmq
8-
import json
8+
import zmq.asyncio
99

1010
from loguru import logger
1111
from typing import TYPE_CHECKING
1212
from ajet.tuner_lib.experimental.oai_model_server import InterchangeCompletionRequest
13-
from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor
13+
from ajet.utils.thread_executors import SharedInterchangeThreadExecutor
1414
from ajet.tuner_lib.experimental.interchange_utils import get_zmq_socket
1515
from ajet.tuner_lib.experimental.interchange_utils import DEBUG
1616

1717
if TYPE_CHECKING:
1818
pass
1919

20-
context = zmq.Context()
20+
context = zmq.asyncio.Context()
2121
atexit.register(context.term)
2222

2323
if TYPE_CHECKING:
@@ -72,11 +72,10 @@ def begin_service(self):
7272
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...")
7373
self.socket = context.socket(zmq.REP)
7474
self.socket.bind(f"{self.episode_contect_address}")
75-
self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 1 second timeout for REP
7675

7776
self.executor = SharedInterchangeThreadExecutor(self.max_inference_tracker_threads).get_shared_executor()
78-
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Submitting _begin_service_threading to executor...")
79-
future = self.executor.submit(self._begin_service_threading)
77+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Submitting _run_service_loop to executor...")
78+
future = self.executor.submit(self._run_service_loop)
8079

8180
# wait till service begin running
8281
wait_time = 1
@@ -94,26 +93,33 @@ def begin_service(self):
9493
return self.episode_contect_address
9594

9695

97-
def _begin_service_threading(self):
98-
"""begin listening for service requests in a threading model
96+
def _run_service_loop(self):
97+
"""Runs a dedicated asyncio event loop for this episode's zmq service.
98+
"""
99+
loop = asyncio.new_event_loop()
100+
asyncio.set_event_loop(loop)
101+
try:
102+
loop.run_until_complete(self._begin_service_async())
103+
finally:
104+
loop.close()
105+
asyncio.set_event_loop(None)
106+
107+
108+
async def _begin_service_async(self):
109+
"""begin listening for service requests using zmq.asyncio
99110
"""
100111

101112
begin_time = time.time()
102113
ever_receive_anything = False
103114
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting ZMQ socket bind complete")
104115

116+
poller = zmq.asyncio.Poller()
117+
poller.register(self.socket, zmq.POLLIN)
118+
105119
try:
106120
while not self.should_hard_terminate:
107-
try:
108-
109-
# <wait for>:
110-
# <from_sourcefile>: ajet/tuner_lib/experimental/oai_model_server.py
111-
# <from_code>: socket.send_string(int_req.model_dump_json())
112-
# <expect>: InterchangeCompletionRequest object in JSON string format
113-
message = self.socket.recv_string()
114-
115-
ever_receive_anything = True
116-
except zmq.Again as e:
121+
events = dict(await poller.poll(timeout=1000)) # 1 second
122+
if self.socket not in events:
117123
if self.should_hard_terminate:
118124
# abort_episode()
119125
if DEBUG: logger.info(f"[client] {self.episode_uuid} | episode over")
@@ -123,51 +129,35 @@ def _begin_service_threading(self):
123129
if DEBUG: logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...")
124130
continue
125131

132+
# <wait for>:
133+
# <from_sourcefile>: ajet/tuner_lib/experimental/oai_model_server.py
134+
# <from_code>: socket.send_string(int_req.model_dump_json())
135+
# <expect>: InterchangeCompletionRequest object in JSON string format
136+
message = await self.socket.recv_string()
137+
ever_receive_anything = True
138+
126139
# parse the incoming request
127140
if DEBUG: logger.info(f"[client] {self.episode_uuid} | before json.loads(message)")
128141
data_as_json = json.loads(message)
129142
parsed_msg = InterchangeCompletionRequest(**data_as_json)
130143

131-
# begin to run the llm request, monitored by context tracker
132-
# we re-use previously created thread for best performance
133-
if DEBUG: logger.info(f"[client] {self.episode_uuid} | before asyncio run self.llm_infer")
134-
135-
# Check if there's a running event loop
136-
try:
137-
loop = asyncio.get_running_loop()
138-
created_new_loop = False
139-
except RuntimeError:
140-
# No running loop, create a new one
141-
loop = asyncio.new_event_loop()
142-
asyncio.set_event_loop(loop)
143-
created_new_loop = True
144-
145-
try:
146-
context_tracker_executor = SharedInferenceTrackerThreadExecutor(self.max_inference_tracker_threads).get_shared_executor()
147-
future = loop.run_in_executor(
148-
context_tracker_executor,
149-
asyncio.run,
150-
self.llm_proxy_with_tracker.chat_completion_request(
151-
req=parsed_msg.completion_request,
152-
timeline_uuid=parsed_msg.timeline_uuid,
153-
agent_name=parsed_msg.agent_name,
154-
target_tag=parsed_msg.target_tag,
155-
episode_uuid=parsed_msg.episode_uuid,
156-
)
157-
)
158-
result = loop.run_until_complete(future).model_dump_json() # type: ignore
159-
finally:
160-
# Clean up the event loop if we created it
161-
if created_new_loop:
162-
loop.close()
163-
asyncio.set_event_loop(None)
144+
# run the llm request, monitored by context tracker
145+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | before awaiting self.llm_infer")
146+
response = await self.llm_proxy_with_tracker.chat_completion_request(
147+
req=parsed_msg.completion_request,
148+
timeline_uuid=parsed_msg.timeline_uuid,
149+
agent_name=parsed_msg.agent_name,
150+
target_tag=parsed_msg.target_tag,
151+
episode_uuid=parsed_msg.episode_uuid,
152+
)
153+
result = response.model_dump_json()
164154

165155
if DEBUG: logger.info(f"[client] {self.episode_uuid} | before send_string (send llm call result)")
166156

167157
# <send to>
168158
# <to_sourcefile>: ajet/tuner_lib/experimental/oai_model_server.py
169159
# <to_code>: result_str = socket.recv_string()
170-
self.socket.send_string(result)
160+
await self.socket.send_string(result)
171161

172162
if DEBUG: logger.info(f"[client] {self.episode_uuid} | after send_string (send llm call result)")
173163
except:

tutorial/example_math_agent/math_agent.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ ajet:
1515

1616
model:
1717
# ✨✨✨✨ set the model to be trained
18-
path: Qwen/Qwen2.5-7B
18+
path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct
1919

2020
rollout:
21-
user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ write and select workflow
21+
# user_workflow: "tutorial.example_math_agent.math_agent->ExampleMathLearn" # ✨✨✨✨ write and select workflow
2222
# user_workflow: "tutorial.example_math_agent.math_agent_langchain->ExampleMathLearn" # ✨if you prefer langchain version
2323
# user_workflow: "tutorial/example_math_agent/math_agent_oai_sdk.py->ExampleMathLearn_Simple_NoToolCall" # ✨if you prefer openai sdk version without toolcall
24-
# user_workflow: "tutorial/example_math_agent/math_agent_oai_sdk.py->ExampleMathLearn" # ✨if you prefer openai sdk version with toolcall
24+
user_workflow: "tutorial/example_math_agent/math_agent_oai_sdk.py->ExampleMathLearn" # ✨if you prefer openai sdk version with toolcall
2525
# user_workflow: "tutorial/example_math_agent/math_agent_raw_http.py->ExampleMathLearn" # ✨if you do not want to use any agentic framwork at all
2626
# user_workflow: "tutorial/example_math_agent/math_agent_simplify.py->MathToolWorkflow" # ✨if you prefer to compute reward inside workflow
2727
temperature: 1.0

tutorial/opencode_build_aime/agent_roll_v3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,17 @@
2525
REMOTE_MODEL_PATH = os.getenv("REMOTE_MODEL_PATH", "/mnt/data_cpfs/xielipeng.xlp/models/Qwen3-14B")
2626
BATCH_SIZE = 32
2727
NUM_REPEAT = 8
28-
MINI_BATCH_NUM = 2
28+
MINI_BATCH_NUM = 1
2929
ajet_job = AgentJetJob(
3030
algorithm="grpo",
31-
experiment_name="aime_swarm_14b_v33_2",
31+
experiment_name="aime_swarm_14b_v33_ppoepoch4",
3232
max_env_worker=128,
3333
n_gpu=8,
3434
model=REMOTE_MODEL_PATH,
3535
batch_size=BATCH_SIZE,
3636
swarm_mode_sample_collection_method="rollout_until_finish_enough_non_dummy_tasks",
3737
num_repeat=NUM_REPEAT,
38+
ppo_epochs=4,
3839
mini_batch_num=MINI_BATCH_NUM,
3940
logging="swanlab",
4041
max_prompt_length=3000,

0 commit comments

Comments
 (0)