import asyncio
+import atexit
import json
import threading
import os
from typing import Optional, List, Dict, Any, Union, TYPE_CHECKING
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse
from openai.types.chat.chat_completion import ChatCompletion
+from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest
from redis.exceptions import TimeoutError
-
+from ajet.utils.free_port import find_free_port
+from ajet.utils.sington import ThreadExecutorLlmInferSingleton, ThreadExecutorSingleton
from functools import cache

import pickle
import httpx
import zmq
import logging
+
logging.getLogger("httpx").setLevel(logging.WARNING)

import base64
if TYPE_CHECKING:
    from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker

-def generate_auth_token(agent_name, target_tag, episode_uuid):
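+# DEBUG toggles the verbose per-episode logging used throughout this module.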
+DEBUG = False
+# DEBUG = True
+
+def generate_auth_token(agent_name, target_tag, episode_uuid, episode_address):
    """
    Generate a Base64-encoded auth_token from the given agent_name, target_tag, and episode_uuid.

@@ -41,7 +48,8 @@ def generate_auth_token(agent_name, target_tag, episode_uuid):
    auth_data = {
        "agent_name": agent_name,
        "target_tag": target_tag,
-        "episode_uuid": episode_uuid
+        "episode_uuid": episode_uuid,
52+ "episode_address" : episode_address ,
4553 }
4654
4755 # Step 2: Convert the dictionary to a JSON string
@@ -68,12 +76,15 @@ def get_redis_connection_pool():
    )
    return pool

-
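+# Cache the client object so every caller reuses a single pool-backed Redis client.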
+@cache
def get_redis_client():
    pool = get_redis_connection_pool()
    return redis.Redis(connection_pool=pool, decode_responses=False, encoding='utf-8')


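+# Module-level ZMQ context shared by every InterchangeClient instance; terminated when the process exits.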
+context = zmq.Context()
+atexit.register(context.term)
+
class InterchangeClient:

    def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker", llm_inference_fn, config):
@@ -82,7 +93,10 @@ def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker
        self.llm_inference_fn = llm_inference_fn
        self.config = config
        self._should_terminate = False
-        self.begin_service()
+
+        # self.episode_contect_address = f"tcp://localhost:{find_free_port()}"
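+        # Each episode gets its own IPC endpoint under /tmp/ajet/, keyed by episode_uuid.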
+        self.ipc_path = f"/tmp/ajet/{self.episode_uuid}.sock"
+        self.episode_contect_address = f"ipc://{self.ipc_path}"


    async def llm_infer(
@@ -124,127 +138,78 @@ def begin_service(self):
124138 """
125139 Starts the SSE service loop.
126140 """
127- t = threading .Thread (target = self ._begin_service_threading , daemon = True )
128- t .start ()
129-
130-
131- def _handle_service_request (self , msg : bytes , sem : threading .Semaphore ):
132- """handle a single service request in its own thread
133- """
134- from ajet .tuner_lib .weight_tuner .experimental .as_oai_model_server import InterchangeCompletionRequest
135- logger .info (f"[client] { self .episode_uuid } | inside _handle_service_request" )
136- redis_client = get_redis_client ()
137- logger .info (f"[client] { self .episode_uuid } | get_redis_client" )
138- data_as_json = ""
139- topic = ""
140- try :
141- data_as_json = json .loads (pickle .loads (msg ))
142- timeline_uuid = data_as_json ["timeline_uuid" ]
143- topic = f"stream:timeline:{ timeline_uuid } "
144- logger .info (f"[client] { self .episode_uuid } | json.loads(pickle.loads(msg))" )
145-
146-
147- if "health_check" in data_as_json and data_as_json ["health_check" ]:
148- # logger.info(f"Received health check for timeline_uuid: {timeline_uuid}")
149- result = '{"health_check_ok": "True"}'
150- # logger.success(f"Health check OK for timeline_uuid: {timeline_uuid}")
151- else :
152- parsed_msg = InterchangeCompletionRequest (** data_as_json )
153- # start llm request
154- result = asyncio .run (self .llm_infer (
155- req = parsed_msg .completion_request ,
156- timeline_uuid = parsed_msg .timeline_uuid ,
157- agent_name = parsed_msg .agent_name ,
158- target_tag = parsed_msg .target_tag ,
159- episode_uuid = parsed_msg .episode_uuid ,
160- )).model_dump_json ()
161- # logger.success(f"LLM inference completed for timeline_uuid: {timeline_uuid}")
162- logger .info (f"[client] { self .episode_uuid } | result = asyncio.run(self.llm_infer" )
163- # send result back
164- bytes_arr = pickle .dumps (result )
165- logger .info (f"[client] { self .episode_uuid } | bytes_arr = pickle.dumps(result)" )
166- redis_client .xadd (topic , {'data' : bytes_arr })
167- redis_client .expire (topic , 600 ) # expire after 10 mins
168- logger .info (f"[client] { self .episode_uuid } | redis_client.xadd(topic, ...)" )
169-
170- except Exception as e :
171- err = f"[ERR]: Error when processing data: { data_as_json } Error: { e } "
172- result = err
173- logger .error (err )
174- if topic :
175- redis_client .xadd (topic , {'data' : pickle .dumps (result )})
176- redis_client .expire (topic , 600 )
177-
178- finally :
179- # release semaphore when done
180- sem .release ()
181- redis_client .close ()
141+ if DEBUG : logger .info (f"[client] { self .episode_uuid } | Starting InterchangeClient service loop..." )
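+        # Bind a REP socket on the per-episode IPC address; requests arrive as JSON strings over this socket.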
+        self.socket = context.socket(zmq.REP)
+        self.socket.bind(f"{self.episode_contect_address}")
+        self.socket.setsockopt(zmq.RCVTIMEO, 2 * 1000)  # receive timeout: 2 seconds (value is in milliseconds)

+        self.executor = ThreadExecutorSingleton().get_executor()
+        if DEBUG: logger.info(f"[client] {self.episode_uuid} | Submitting _begin_service_threading to executor...")
+        future = self.executor.submit(self._begin_service_threading)
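+        # Block until the executor has actually started the service loop (peeks at the future's private _state).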
+        time.sleep(1)
+        while future._state == 'PENDING':
+            time.sleep(1)
+        if DEBUG: logger.info(f"[client] {self.episode_uuid} | Future ready...")

+        # t = threading.Thread(target=self._begin_service_threading, daemon=True)
+        # t.start()
+        return self.episode_contect_address


    def _begin_service_threading(self):
        """begin listening for service requests in a threading model
        """
-        # logger.success(f"InterchangeClient starting for episode_uuid:{self.episode_uuid}")
-        # debug_logs = []
-        begin_time = time.time()
-        logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...")
-        redis_client = get_redis_client()
-        episode_stream = f"stream:episode:{self.episode_uuid}"
-
-        sem = threading.Semaphore(8)  # 4 concurrent requests max
-        logger.info(f"[client] {self.episode_uuid} | Listening to stream {episode_stream}, waiting for messages...")

-        last_id = '0-0'
-        is_init = True
+        begin_time = time.time()
+        if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting ZMQ socket bind complete")

        try:
            while not self.should_terminate:
-                # wait for a new message
-                logger.info(f"[client] {self.episode_uuid} | Waiting for new message on stream {episode_stream}...")

-                # Check messages
                try:
-                    response = redis_client.xread({episode_stream: last_id}, count=1, block=30 * 1000)  # block for 30 seconds (30000 ms)
-                except TimeoutError:
-                    time.sleep(5)
-                    continue
-
-                timepassed = time.time() - begin_time
-
-                if not response:
-                    if is_init and timepassed > 30:
-                        logger.warning(f"[client] Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...")
+                    if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun")
+                    message = self.socket.recv_string()
+                    if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done")
+                except zmq.Again as e:
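+                    # RCVTIMEO expired with no request: stop if the episode has terminated, otherwise keep polling.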
+                    if self.should_terminate:
+                        if DEBUG: logger.info(f"[client] {self.episode_uuid} | episode over")
+                        break
+                    timepassed = time.time() - begin_time
+                    if timepassed > 60:
+                        logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...")
                    continue

-                # Got message
-                is_init = False
-                logger.info(f"[client] {self.episode_uuid} | get message...")
-
-                stream_result = response[0]
-                messages = stream_result[1]
-                msg_id, data_dict = messages[0]
-
-                last_id = msg_id
-
-                if b'data' in data_dict:
-                    msg: bytes = data_dict[b'data']
-                else:
-                    logger.error(f"Missing 'data' in stream message {msg_id}")
-                    continue
-
-                # are we free to spawn a new thread?
-                sem.acquire()
-                logger.info(f"[client] {self.episode_uuid} | sem acquire...")
-                # begin a new thread to handle this request
-                threading.Thread(target=self._handle_service_request, args=(msg, sem), daemon=True).start()
-
+                if DEBUG: logger.info(f"[client] {self.episode_uuid} | before json.loads(message)")
+                data_as_json = json.loads(message)
+                parsed_msg = InterchangeCompletionRequest(**data_as_json)

-            except KeyboardInterrupt:
-                return
+                if DEBUG: logger.info(f"[client] {self.episode_uuid} | before asyncio run self.llm_infer")

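+                # Run the llm_infer coroutine to completion on a worker thread from the shared
+                # inference executor (via asyncio.run), and block this service loop until it returns.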
+                try:
+                    loop = asyncio.get_running_loop()
+                except RuntimeError:
+                    loop = asyncio.new_event_loop()
+                executor = ThreadExecutorLlmInferSingleton().get_executor()
+                future = loop.run_in_executor(
+                    executor,  # executor
+                    asyncio.run,
+                    self.llm_infer(
+                        req=parsed_msg.completion_request,
+                        timeline_uuid=parsed_msg.timeline_uuid,
+                        agent_name=parsed_msg.agent_name,
+                        target_tag=parsed_msg.target_tag,
+                        episode_uuid=parsed_msg.episode_uuid,
+                    )
+                )
+                result = loop.run_until_complete(future).model_dump_json()  # type: ignore
+
+                if DEBUG: logger.info(f"[client] {self.episode_uuid} | before send_string")
+                self.socket.send_string(result)
+        except Exception:
+            logger.exception(f"[client] {self.episode_uuid} | Exception occurred in service loop.")
        finally:
-            redis_client.delete(episode_stream)
-            redis_client.close()
-
+            self.socket.close()
+            if DEBUG: logger.info(f"[client] {self.episode_uuid} | ZMQ socket closed, service loop terminated.")
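+            # Remove the per-episode IPC socket file so stale endpoints do not accumulate under /tmp/ajet.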
+            if os.path.exists(self.ipc_path):
+                os.remove(self.ipc_path)
+                if DEBUG: logger.info(f"[client] {self.episode_uuid} | IPC socket file {self.ipc_path} removed.")