Skip to content

Commit 2c05d11

Browse files
committed
optimize parallel performance with zmq
1 parent 1da86e2 commit 2c05d11

12 files changed

Lines changed: 190 additions & 250 deletions

File tree

ajet/backbone/trainer_verl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -834,10 +834,10 @@ def fit(self): # noqa: C901
834834
progress_bar.update(1)
835835
self.global_steps += 1
836836

837-
# when enabled oai request interchange, we need to clear the cache from time to time
838-
if self.config.ajet.enable_experimental_reverse_proxy:
839-
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear
840-
ensure_dat_interchange_server_cache_clear()
837+
# # when enabled oai request interchange, we need to clear the cache from time to time
838+
# if self.config.ajet.enable_experimental_reverse_proxy:
839+
# from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear
840+
# ensure_dat_interchange_server_cache_clear()
841841

842842
if is_last_step:
843843
pprint(f"Final validation metrics: {last_val_metrics}")

ajet/default_config/ajet_default.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ ajet:
99
# the experimental reverse proxy feature that allows `tuner.as_oai_baseurl_apikey` feature
1010
enable_experimental_reverse_proxy: False
1111

12+
# method used to submit llm inference requests
13+
llm_infer_submit_method: "async" # options: "sync", "async"
14+
1215
task_runner:
1316
wrapper_type: "asyncio-with-gc"
1417
wrapper_multiprocessing_timeout: 3600 # in seconds

ajet/task_rollout/async_llm_bridge.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(
6262
self.tokenizer = tokenizer
6363
self.llm_mode = llm_mode
6464
self.max_llm_retries = max_llm_retries
65-
65+
self.tool_parser = Hermes2ProToolParser(self.tokenizer)
6666

6767
def get_llm_inference_fn_sync(self, sampling_params: dict = {}) -> Callable: # noqa: C901
6868

@@ -123,8 +123,7 @@ def llm_chat_verl(
123123
and ("</tool_call>" in decoded_text)
124124
and (not self.config.ajet.rollout.force_disable_toolcalls)
125125
):
126-
tool_parser = Hermes2ProToolParser(self.tokenizer)
127-
parsed_tool_calls = tool_parser.extract_tool_calls(decoded_text, None) # type: ignore
126+
parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore
128127
parsed_tool_calls = parsed_tool_calls.model_dump()
129128
if self.config.ajet.execute_test:
130129
_test_if_test_mode(
@@ -323,8 +322,8 @@ async def llm_chat_verl(
323322
and ("</tool_call>" in decoded_text)
324323
and (not self.config.ajet.rollout.force_disable_toolcalls)
325324
):
326-
tool_parser = Hermes2ProToolParser(self.tokenizer)
327-
parsed_tool_calls = tool_parser.extract_tool_calls(decoded_text, None) # type: ignore
325+
326+
parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore
328327
parsed_tool_calls = parsed_tool_calls.model_dump()
329328
if self.config.ajet.execute_test:
330329
_test_if_test_mode(
@@ -535,14 +534,15 @@ async def run_infer(
535534
# otherwise, for abnormal output, can still proceed, but we do not track output anymore
536535

537536
# run llm inference ✨
538-
# if sync:
539-
# llm_output = await asyncio.wait_for(
540-
# asyncio.to_thread(
541-
# self.llm_inference_fn, converted_message, custom_sampling_params, tools
542-
# ),
543-
# timeout=1800,
544-
# )
545-
llm_output = await asyncio.wait_for(self.llm_inference_fn(converted_message, custom_sampling_params, tools), timeout=1800)
537+
if self.config.ajet.llm_infer_submit_method == "sync":
538+
llm_output = await asyncio.wait_for(
539+
asyncio.to_thread(
540+
self.llm_inference_fn, converted_message, custom_sampling_params, tools
541+
),
542+
timeout=1800,
543+
)
544+
else:
545+
llm_output = await asyncio.wait_for(self.llm_inference_fn(converted_message, custom_sampling_params, tools), timeout=1800)
546546

547547

548548
# begin context tracking

ajet/task_rollout/single_worker.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,15 @@ def rollout_env_worker(
8484
(with validation overrides), and robust retry on transient failures.
8585
"""
8686
sampling_params = get_sample_params(mode, self.config)
87-
llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_async(
88-
sampling_params=sampling_params
89-
)
87+
88+
if self.config.ajet.llm_infer_submit_method == "sync":
89+
llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_sync(
90+
sampling_params=sampling_params
91+
)
92+
else:
93+
llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_async(
94+
sampling_params=sampling_params
95+
)
9096

9197
workflow_task = WorkflowTask(
9298
env_type=task.env_type,

ajet/tuner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def as_oai_baseurl_apikey(
114114
agent_name=agent_name,
115115
target_tag=target_tag,
116116
episode_uuid=self.context_tracker.episode_uuid,
117+
episode_contect_address=self.interchange_client.episode_contect_address,
117118
)
118119
return baseurl_apikey_model
119120

@@ -178,6 +179,7 @@ def _enable_experimental_interchange_server(self, llm_inference_fn):
178179
config=self.config,
179180
llm_inference_fn=llm_inference_fn,
180181
)
182+
return self.interchange_client.begin_service()
181183

182184

183185
def terminate_episode(self):

ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from openai.resources.chat.chat import Chat, AsyncChat
1313
from openai.resources.completions import AsyncCompletions
1414
from openai import OpenAI, AsyncOpenAI
15+
from ajet.utils.free_port import find_free_port
1516
from .experimental.as_oai_model_client import generate_auth_token
1617

1718
if TYPE_CHECKING:
@@ -43,6 +44,7 @@ def __init__(
4344
target_tag: str,
4445
agent_name: str,
4546
episode_uuid: str,
47+
episode_contect_address: str,
4648
**kwargs,
4749
):
4850
port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
@@ -52,6 +54,7 @@ def __init__(
5254
agent_name=agent_name,
5355
target_tag=target_tag,
5456
episode_uuid=episode_uuid,
57+
episode_address=episode_contect_address,
5558
)
5659
model = "reserved_field"
5760

Lines changed: 76 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
import asyncio
3+
import atexit
34
import json
45
import threading
56
import os
@@ -9,14 +10,17 @@
910
from typing import Optional, List, Dict, Any, Union, TYPE_CHECKING
1011
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse
1112
from openai.types.chat.chat_completion import ChatCompletion
13+
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest
1214
from redis.exceptions import TimeoutError
13-
15+
from ajet.utils.free_port import find_free_port
16+
from ajet.utils.sington import ThreadExecutorLlmInferSingleton, ThreadExecutorSingleton
1417
from functools import cache
1518

1619
import pickle
1720
import httpx
1821
import zmq
1922
import logging
23+
2024
logging.getLogger("httpx").setLevel(logging.WARNING)
2125

2226
import base64
@@ -25,7 +29,10 @@
2529
if TYPE_CHECKING:
2630
from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker
2731

28-
def generate_auth_token(agent_name, target_tag, episode_uuid):
32+
DEBUG = False
33+
# DEBUG = True
34+
35+
def generate_auth_token(agent_name, target_tag, episode_uuid, episode_address):
2936
"""
3037
Generate a Base64-encoded auth_token from the given agent_name, target_tag, episode_uuid, and episode_address.
3138
@@ -41,7 +48,8 @@ def generate_auth_token(agent_name, target_tag, episode_uuid):
4148
auth_data = {
4249
"agent_name": agent_name,
4350
"target_tag": target_tag,
44-
"episode_uuid": episode_uuid
51+
"episode_uuid": episode_uuid,
52+
"episode_address": episode_address,
4553
}
4654

4755
# Step 2: Convert the dictionary to a JSON string
@@ -68,12 +76,15 @@ def get_redis_connection_pool():
6876
)
6977
return pool
7078

71-
79+
@cache
7280
def get_redis_client():
7381
pool = get_redis_connection_pool()
7482
return redis.Redis(connection_pool=pool, decode_responses=False, encoding='utf-8')
7583

7684

85+
context = zmq.Context()
86+
atexit.register(context.term)
87+
7788
class InterchangeClient:
7889

7990
def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker", llm_inference_fn, config):
@@ -82,7 +93,10 @@ def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker
8293
self.llm_inference_fn = llm_inference_fn
8394
self.config = config
8495
self._should_terminate = False
85-
self.begin_service()
96+
97+
# self.episode_contect_address = f"tcp://localhost:{find_free_port()}"
98+
self.ipc_path = f"/tmp/ajet/{self.episode_uuid}.sock"
99+
self.episode_contect_address = f"ipc://{self.ipc_path}"
86100

87101

88102
async def llm_infer(
@@ -124,127 +138,78 @@ def begin_service(self):
124138
"""
125139
Starts the SSE service loop.
126140
"""
127-
t = threading.Thread(target=self._begin_service_threading, daemon=True)
128-
t.start()
129-
130-
131-
def _handle_service_request(self, msg: bytes, sem: threading.Semaphore):
132-
"""handle a single service request in its own thread
133-
"""
134-
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest
135-
logger.info(f"[client] {self.episode_uuid} | inside _handle_service_request")
136-
redis_client = get_redis_client()
137-
logger.info(f"[client] {self.episode_uuid} | get_redis_client")
138-
data_as_json = ""
139-
topic = ""
140-
try:
141-
data_as_json = json.loads(pickle.loads(msg))
142-
timeline_uuid = data_as_json["timeline_uuid"]
143-
topic = f"stream:timeline:{timeline_uuid}"
144-
logger.info(f"[client] {self.episode_uuid} | json.loads(pickle.loads(msg))")
145-
146-
147-
if "health_check" in data_as_json and data_as_json["health_check"]:
148-
# logger.info(f"Received health check for timeline_uuid: {timeline_uuid}")
149-
result = '{"health_check_ok": "True"}'
150-
# logger.success(f"Health check OK for timeline_uuid: {timeline_uuid}")
151-
else:
152-
parsed_msg = InterchangeCompletionRequest(**data_as_json)
153-
# start llm request
154-
result = asyncio.run(self.llm_infer(
155-
req=parsed_msg.completion_request,
156-
timeline_uuid=parsed_msg.timeline_uuid,
157-
agent_name=parsed_msg.agent_name,
158-
target_tag=parsed_msg.target_tag,
159-
episode_uuid=parsed_msg.episode_uuid,
160-
)).model_dump_json()
161-
# logger.success(f"LLM inference completed for timeline_uuid: {timeline_uuid}")
162-
logger.info(f"[client] {self.episode_uuid} | result = asyncio.run(self.llm_infer")
163-
# send result back
164-
bytes_arr = pickle.dumps(result)
165-
logger.info(f"[client] {self.episode_uuid} | bytes_arr = pickle.dumps(result)")
166-
redis_client.xadd(topic, {'data': bytes_arr})
167-
redis_client.expire(topic, 600) # expire after 10 mins
168-
logger.info(f"[client] {self.episode_uuid} | redis_client.xadd(topic, ...)")
169-
170-
except Exception as e:
171-
err = f"[ERR]: Error when processing data: {data_as_json} Error: {e}"
172-
result = err
173-
logger.error(err)
174-
if topic:
175-
redis_client.xadd(topic, {'data': pickle.dumps(result)})
176-
redis_client.expire(topic, 600)
177-
178-
finally:
179-
# release semaphore when done
180-
sem.release()
181-
redis_client.close()
141+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...")
142+
self.socket = context.socket(zmq.REP)
143+
self.socket.bind(f"{self.episode_contect_address}")
144+
self.socket.setsockopt(zmq.RCVTIMEO, 2*1000)  # 2-second receive timeout
182145

146+
self.executor = ThreadExecutorSingleton().get_executor()
147+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Submitting _begin_service_threading to executor...")
148+
future = self.executor.submit(self._begin_service_threading)
149+
time.sleep(1)
150+
while future._state == 'PENDING':
151+
time.sleep(1)
152+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Future ready...")
183153

154+
# t = threading.Thread(target=self._begin_service_threading, daemon=True)
155+
# t.start()
156+
return self.episode_contect_address
184157

185158

186159
def _begin_service_threading(self):
187160
"""begin listening for service requests in a threading model
188161
"""
189-
# logger.success(f"InterchangeClient starting for episode_uuid:{self.episode_uuid}")
190-
# debug_logs = []
191-
begin_time = time.time()
192-
logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...")
193-
redis_client = get_redis_client()
194-
episode_stream = f"stream:episode:{self.episode_uuid}"
195-
196-
sem = threading.Semaphore(8) # 4 concurrent requests max
197-
logger.info(f"[client] {self.episode_uuid} | Listening to stream {episode_stream}, waiting for messages...")
198162

199-
last_id = '0-0'
200-
is_init = True
163+
begin_time = time.time()
164+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting ZMQ socket bind complete")
201165

202166
try:
203167
while not self.should_terminate:
204-
# wait for a new message
205-
logger.info(f"[client] {self.episode_uuid} | Waiting for new message on stream {episode_stream}...")
206168

207-
# Check messages
208169
try:
209-
response = redis_client.xread({episode_stream: last_id}, count=1, block=30*1000) # block for 30 seconds (30000 ms)
210-
except TimeoutError:
211-
time.sleep(5)
212-
continue
213-
214-
timepassed = time.time() - begin_time
215-
216-
if not response:
217-
if is_init and timepassed > 30:
218-
logger.warning(f"[client] Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...")
170+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun")
171+
message = self.socket.recv_string()
172+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done")
173+
except zmq.Again as e:
174+
if self.should_terminate:
175+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | episode over")
176+
break
177+
timepassed = time.time() - begin_time
178+
if timepassed > 60:
179+
logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...")
219180
continue
220181

221-
# Got message
222-
is_init = False
223-
logger.info(f"[client] {self.episode_uuid} | get message...")
224-
225-
stream_result = response[0]
226-
messages = stream_result[1]
227-
msg_id, data_dict = messages[0]
228-
229-
last_id = msg_id
230-
231-
if b'data' in data_dict:
232-
msg: bytes = data_dict[b'data']
233-
else:
234-
logger.error(f"Missing 'data' in stream message {msg_id}")
235-
continue
236-
237-
# are we free to spawn a new thread?
238-
sem.acquire()
239-
logger.info(f"[client] {self.episode_uuid} | sem acquire...")
240-
# begin a new thread to handle this request
241-
threading.Thread(target=self._handle_service_request, args=(msg, sem), daemon=True).start()
242-
182+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | before json.loads(message)")
183+
data_as_json = json.loads(message)
184+
parsed_msg = InterchangeCompletionRequest(**data_as_json)
243185

244-
except KeyboardInterrupt:
245-
return
186+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | before asyncio run self.llm_infer")
246187

188+
try:
189+
loop = asyncio.get_running_loop()
190+
except:
191+
loop = asyncio.new_event_loop()
192+
executor = ThreadExecutorLlmInferSingleton().get_executor()
193+
future = loop.run_in_executor(
194+
executor, # executor
195+
asyncio.run,
196+
self.llm_infer(
197+
req=parsed_msg.completion_request,
198+
timeline_uuid=parsed_msg.timeline_uuid,
199+
agent_name=parsed_msg.agent_name,
200+
target_tag=parsed_msg.target_tag,
201+
episode_uuid=parsed_msg.episode_uuid,
202+
)
203+
)
204+
result = loop.run_until_complete(future).model_dump_json() # type: ignore
205+
206+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | before send_string")
207+
self.socket.send_string(result)
208+
except:
209+
logger.exception(f"[client] {self.episode_uuid} | Exception occurred in service loop.")
247210
finally:
248-
redis_client.delete(episode_stream)
249-
redis_client.close()
250-
211+
self.socket.close()
212+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | ZMQ socket closed, service loop terminated.")
213+
if os.path.exists(self.ipc_path):
214+
os.remove(self.ipc_path)
215+
if DEBUG: logger.info(f"[client] {self.episode_uuid} | IPC socket file {self.ipc_path} removed.")

0 commit comments

Comments
 (0)