1111
1212import asyncio
1313from functools import cache
14+ from contextlib import asynccontextmanager
1415from multiprocessing import Process
1516import threading
1617from concurrent .futures import ThreadPoolExecutor
@@ -72,178 +73,174 @@ def get_redis_client():
7273 pool = get_redis_connection_pool ()
7374 return redis .Redis (connection_pool = pool , decode_responses = False , encoding = 'utf-8' )
7475
# Create FastAPI app
# Module-level flag so worker threads (spawned via the app's executor) can
# observe server shutdown and exit their Redis polling loops promptly.
SERVER_SHUTDOWN_EVENT = threading.Event()


def get_app():
    """Build and return the FastAPI application for the interchange endpoint.

    Implemented as a factory (rather than a module-level ``app``) so each
    InterchangeServer subprocess constructs its own application instance.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Startup: reset the shutdown flag and create the thread pool used to
        # run blocking Redis calls off the event loop.
        SERVER_SHUTDOWN_EVENT.clear()
        app.state.executor = ThreadPoolExecutor(max_workers=512)
        yield
        # Shutdown: signal worker threads to stop polling, then tear the pool
        # down without waiting so shutdown is not blocked by in-flight work.
        SERVER_SHUTDOWN_EVENT.set()
        app.state.executor.shutdown(wait=False, cancel_futures=True)

    app = FastAPI(title="AJet Interchange Endpoint", lifespan=lifespan)

    def _begin_handle_chat_completion(int_req, episode_uuid, timeline_uuid, client_offline: threading.Event):
        """Run this in a thread to avoid blocking the main event loop.

        Publishes the serialized request on the episode stream, then polls the
        timeline stream until a result arrives, the client disconnects, the
        server shuts down, or the 10-minute timeout elapses.

        Returns:
            ChatCompletion on success, or None if the loop exits because the
            client went away / the server is stopping.
        Raises:
            HTTPException: 504 on timeout, 500 on error responses or
                communication failure (propagates through run_in_executor to
                the awaiting endpoint, where FastAPI converts it to a proper
                HTTP error response).
        """
        logger.info(f"episode_uuid: {episode_uuid} | Received new chat completion request for episode_uuid: {episode_uuid}, timeline_uuid: {timeline_uuid} (inside thread)")

        redis_client = get_redis_client()
        episode_stream = f"stream:episode:{episode_uuid}"
        timeline_stream = f"stream:timeline:{timeline_uuid}"

        max_wait_time = 600  # 10 minutes timeout
        try:
            logger.info(f"episode_uuid: {episode_uuid} | redis_client.xadd int_req ")
            # NOTE(review): pickle over Redis is only safe because both ends of
            # the stream are trusted components of this system; never feed
            # untrusted bytes to pickle.loads on the consumer side.
            redis_client.xadd(episode_stream, {'data': pickle.dumps(int_req.model_dump_json())})
            logger.info(f"episode_uuid: {episode_uuid} | redis_client.xadd int_req end")

            # record start
            begin_time = time.time()

            # wait for result
            last_id = '0-0'
            while (not client_offline.is_set()) and (not SERVER_SHUTDOWN_EVENT.is_set()):
                timepassed = time.time() - begin_time
                if timepassed > max_wait_time:
                    # BUG FIX: raise (not return) so the endpoint produces a
                    # real 504 response instead of serializing the exception
                    # object as a 200 JSON body.
                    raise HTTPException(status_code=504, detail="Request timeout")
                try:
                    logger.info(f"episode_uuid: {episode_uuid} | redis_client.xread block=30000")
                    # Block for 30 seconds to allow loop to check client_offline
                    response = redis_client.xread({timeline_stream: last_id}, count=1, block=30 * 1000)  # block for 30 seconds
                    logger.info(f"episode_uuid: {episode_uuid} | redis_client.xread after")

                    if not response:
                        if timepassed > 60:
                            logger.warning(f"episode_uuid: {episode_uuid} | LLM client infer still waiting... (time passed {timepassed}) for episode_uuid:{episode_uuid}, timeline_uuid:{timeline_uuid}...")
                        continue

                    # response format: [[stream_name, [[message_id, data_dict]]]]
                    stream_result = response[0]  # type: ignore
                    messages = stream_result[1]
                    message_id, data_dict = messages[0]

                    logger.info(f"episode_uuid: {episode_uuid} | successfully get message from redis stream")

                    # Retrieve data, decode_responses=False so keys/values are bytes
                    if b'data' in data_dict:
                        data_bytes = data_dict[b'data']
                    else:
                        logger.error(f"Missing 'data' field in stream message: {data_dict}")
                        continue

                    result_object_str = pickle.loads(data_bytes)

                    if result_object_str.startswith('[ERR]'):
                        # BUG FIX: raise instead of return (see 504 above).
                        raise HTTPException(status_code=500, detail="Error response, " + result_object_str)
                    result_object = ChatCompletion(**json.loads(result_object_str))

                    # Cleanup stream
                    redis_client.delete(timeline_stream)

                    return result_object

                except HTTPException:
                    # Deliberate HTTP errors must not be swallowed by the
                    # generic handler below.
                    raise
                except TimeoutError:
                    logger.info(f"episode_uuid: {episode_uuid} | still waiting, (time passed {timepassed}) for result for episode_uuid:{episode_uuid}, timeline_uuid:{timeline_uuid}...")
                    continue
                except Exception as e:
                    logger.error(f"Error reading from stream: {e}")
                    if timepassed > max_wait_time:
                        raise e
                    time.sleep(1)

            # Loop exited because the client went away or the server is
            # shutting down: nobody consumes the result, so return nothing.
            return None

        except HTTPException:
            # Already a proper HTTP error (504/500 raised above) — propagate.
            raise
        except Exception as e:
            logger.error(f"Communication failed: {e}")
            # BUG FIX: raise instead of return so the client sees a real 500.
            raise HTTPException(status_code=500, detail=f"Communication failed: {e}") from e

        finally:
            redis_client.close()

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request, authorization: str = Header(None)):
        """
        OpenAI-compatible chat completions endpoint.
        Receives ChatCompletionRequest and returns ChatCompletion.
        """
        # Parse authorization header (base64 encoded JSON)
        if not authorization:
            # BUG FIX: raise (not return) so FastAPI emits a 401 status
            # instead of a 200 response whose body is the exception object.
            raise HTTPException(status_code=401, detail="Missing authorization header")

        try:
            # Remove "Bearer " prefix if present
            auth_token = authorization.replace("Bearer ", "").replace("bearer ", "")
            decoded = base64.b64decode(auth_token).decode('utf-8')
            auth_data = json.loads(decoded)

            agent_name = auth_data.get("agent_name")
            target_tag = auth_data.get("target_tag")
            episode_uuid = auth_data.get("episode_uuid")

            if not all([agent_name, target_tag, episode_uuid]):
                raise HTTPException(status_code=401, detail="Invalid authorization data")
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=401, detail=f"Invalid authorization header: {str(e)}") from e

        # Parse request body
        body = await request.json()
        new_req = ChatCompletionRequest.model_validate(body)
        if new_req.stream:
            raise HTTPException(status_code=400, detail="Streaming responses not supported in current AgentJet version, please set `stream=false` for now.")
        # Create timeline UUID
        timeline_uuid = uuid.uuid4().hex

        # Add to received queue
        # logger.warning(f"Received new chat completion request for agent: {agent_name}, target_tag: {target_tag}, episode_uuid: {episode_uuid}, timeline_uuid: {timeline_uuid}")
        int_req = InterchangeCompletionRequest(
            completion_request=new_req,
            agent_name=agent_name,
            target_tag=target_tag,
            episode_uuid=episode_uuid,
            timeline_uuid=timeline_uuid,
        )
        logger.info(f"episode_uuid: {episode_uuid} | Received new chat completion request for episode_uuid: {episode_uuid}, timeline_uuid: {timeline_uuid} (outside thread)")
        client_offline = threading.Event()
        try:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, int_req, episode_uuid, timeline_uuid, client_offline)
        finally:
            # Always mark the client as gone (also runs when this coroutine is
            # cancelled on disconnect) so the worker thread stops polling.
            client_offline.set()

    @app.post("/reset")
    async def reset():
        # No server-side state is cleared here currently; the endpoint exists
        # for client compatibility.
        return {"status": "reset_complete"}

    return app
class InterchangeServer(Process):
    """Subprocess that hosts the interchange FastAPI app (served elsewhere in run())."""

    def __init__(self, experiment_dir: str, port: int):
        # Store configuration on the instance; multiprocessing pickles these
        # attributes across to the child process before run() executes.
        super().__init__()
        self.experiment_dir = experiment_dir
        self.port = port
243241 def run (self ):
242+ app = get_app ()
244243 async def serve_with_monitor ():
245- # Start the monitor task
246- asyncio .create_task (monitor_debug_state (self .experiment_dir ))
247244 # Start the server
248245 config = uvicorn .Config (
249246 app = app ,
@@ -254,8 +251,11 @@ async def serve_with_monitor():
254251 )
255252 server = uvicorn .Server (config )
256253 await server .serve ()
257-
258- asyncio .run (serve_with_monitor ())
254+ try :
255+ asyncio .run (serve_with_monitor ())
256+ except KeyboardInterrupt as e :
257+ SERVER_SHUTDOWN_EVENT .set ()
258+ raise e
259259
260260
261261# Convenience function for quick server startup
@@ -270,7 +270,6 @@ def start_interchange_server(experiment_dir) -> int:
270270 os .environ ["AJET_DAT_INTERCHANGE_PORT" ] = str (port )
271271
272272 interchange_server = InterchangeServer (experiment_dir , port )
273- interchange_server .daemon = True
274273 interchange_server .start ()
275274
276275 # Wait for server to be ready
@@ -288,6 +287,7 @@ def start_interchange_server(experiment_dir) -> int:
288287 time .sleep (0.5 )
289288
290289 logger .info (f"Interchange server subprocess started on port { port } (pid: { interchange_server .pid } )" )
290+ atexit .register (lambda : interchange_server .terminate ())
291291 return port
292292
293293
0 commit comments