Skip to content

Commit 1da86e2

Browse files
committed
fix ctrl + c exit problem
1 parent fc0a680 commit 1da86e2

File tree

3 files changed

+155
-336
lines changed

ajet/backbone/main_vllm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import atexit
12
import os
23
import sys
34
from types import SimpleNamespace
@@ -178,11 +179,10 @@ def run(config):
178179
)
179180
def main(config):
180181
from omegaconf import OmegaConf
181-
182182
OmegaConf.resolve(config)
183-
184183
runtime_env = get_runtime_env()
185184
os.environ.update(runtime_env["env_vars"])
185+
# atexit.register(lambda: print("Process exiting, performing cleanup..."))
186186

187187
if config.ajet.enable_experimental_reverse_proxy:
188188
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 153 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import asyncio
1313
from functools import cache
14+
from contextlib import asynccontextmanager
1415
from multiprocessing import Process
1516
import threading
1617
from concurrent.futures import ThreadPoolExecutor
@@ -72,178 +73,174 @@ def get_redis_client():
7273
pool = get_redis_connection_pool()
7374
return redis.Redis(connection_pool=pool, decode_responses=False, encoding='utf-8')
7475

# Module-level flag shared between the uvicorn event loop and the worker
# threads: set on server shutdown (or Ctrl+C) so blocked request-handler
# threads can exit their polling loops promptly.
SERVER_SHUTDOWN_EVENT = threading.Event()


def get_app():
    """Build and return the FastAPI app for the AJet interchange endpoint.

    The app is built in a factory (rather than at import time) so each
    `InterchangeServer` subprocess constructs a fresh app whose lifespan
    manages its own thread pool and resets SERVER_SHUTDOWN_EVENT.
    """

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Startup: re-arm the shutdown flag and create the worker pool used
        # to run blocking redis polling off the event loop.
        SERVER_SHUTDOWN_EVENT.clear()
        app.state.executor = ThreadPoolExecutor(max_workers=512)
        yield
        # Shutdown: signal worker loops to stop, then drop the pool without
        # waiting so Ctrl+C is not blocked by in-flight requests.
        SERVER_SHUTDOWN_EVENT.set()
        app.state.executor.shutdown(wait=False, cancel_futures=True)

    app = FastAPI(title="AJet Interchange Endpoint", lifespan=lifespan)

    def _begin_handle_chat_completion(int_req, episode_uuid, timeline_uuid, client_offline: threading.Event):
        """ run this in thread to avoid blocking main event loop

        Pushes the request onto the episode's redis stream, then polls the
        timeline stream for the result. Returns a ChatCompletion on success,
        an HTTPException instance on failure (the caller raises it), or None
        if the client disconnected / the server is shutting down.
        """
        logger.info(f"episode_uuid: {episode_uuid} | Received new chat completion request for episode_uuid: {episode_uuid}, timeline_uuid: {timeline_uuid} (inside thread)")

        redis_client = get_redis_client()
        episode_stream = f"stream:episode:{episode_uuid}"
        timeline_stream = f"stream:timeline:{timeline_uuid}"

        max_wait_time = 600  # 10 minutes timeout
        try:
            logger.info(f"episode_uuid: {episode_uuid} | redis_client.xadd int_req ")
            redis_client.xadd(episode_stream, {'data': pickle.dumps(int_req.model_dump_json())})
            logger.info(f"episode_uuid: {episode_uuid} | redis_client.xadd int_req end")

            # record start
            begin_time = time.time()

            # wait for result
            last_id = '0-0'
            while (not client_offline.is_set()) and (not SERVER_SHUTDOWN_EVENT.is_set()):
                timepassed = time.time() - begin_time
                if timepassed > max_wait_time:
                    return HTTPException(status_code=504, detail="Request timeout")
                try:
                    logger.info(f"episode_uuid: {episode_uuid} | redis_client.xread block=30000")
                    # Block for 30 seconds to allow loop to check client_offline
                    response = redis_client.xread({timeline_stream: last_id}, count=1, block=30*1000)  # block for 30 seconds
                    logger.info(f"episode_uuid: {episode_uuid} | redis_client.xread after")

                    if not response:
                        if timepassed > 60:
                            logger.warning(f"episode_uuid: {episode_uuid} | LLM client infer still waiting... (time passed {timepassed}) for episode_uuid:{episode_uuid}, timeline_uuid:{timeline_uuid}...")
                        continue

                    # response format: [[stream_name, [[message_id, data_dict]]]]
                    stream_result = response[0]  # type: ignore
                    messages = stream_result[1]
                    message_id, data_dict = messages[0]

                    logger.info(f"episode_uuid: {episode_uuid} | successfully get message from redis stream")

                    # Retrieve data, decode_responses=False so keys/values are bytes
                    if b'data' in data_dict:
                        data_bytes = data_dict[b'data']
                    else:
                        logger.error(f"Missing 'data' field in stream message: {data_dict}")
                        # BUGFIX: advance past the malformed entry; otherwise the
                        # next xread re-delivers the same message forever.
                        last_id = message_id
                        continue

                    # NOTE(review): pickle.loads on bytes read from redis — safe only
                    # while the stream is written exclusively by trusted producers.
                    result_object_str = pickle.loads(data_bytes)

                    if result_object_str.startswith('[ERR]'):
                        return HTTPException(status_code=500, detail="Error response, " + result_object_str)
                    result_object = ChatCompletion(**json.loads(result_object_str))

                    # Cleanup stream
                    redis_client.delete(timeline_stream)

                    return result_object

                except TimeoutError:
                    logger.info(f"episode_uuid: {episode_uuid} | still waiting, (time passed {timepassed}) for result for episode_uuid:{episode_uuid}, timeline_uuid:{timeline_uuid}...")
                    continue
                except Exception as e:
                    logger.error(f"Error reading from stream: {e}")
                    if timepassed > max_wait_time:
                        raise e
                    time.sleep(1)

        except Exception as e:
            logger.error(f"Communication failed: {e}")
            return HTTPException(status_code=500, detail=f"Communication failed: {e}")

        finally:
            redis_client.close()

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request, authorization: str = Header(None)):
        """
        OpenAI-compatible chat completions endpoint.
        Receives ChatCompletionRequest and returns ChatCompletion.
        """
        # Parse authorization header (base64 encoded JSON)
        # BUGFIX: HTTPException must be *raised*, not returned, for FastAPI to
        # send the intended status code to the client.
        if not authorization:
            raise HTTPException(status_code=401, detail="Missing authorization header")

        try:
            # Remove "Bearer " prefix if present
            auth_token = authorization.replace("Bearer ", "").replace("bearer ", "")
            decoded = base64.b64decode(auth_token).decode('utf-8')
            auth_data = json.loads(decoded)

            agent_name = auth_data.get("agent_name")
            target_tag = auth_data.get("target_tag")
            episode_uuid = auth_data.get("episode_uuid")

            if not all([agent_name, target_tag, episode_uuid]):
                raise HTTPException(status_code=401, detail="Invalid authorization data")
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=401, detail=f"Invalid authorization header: {str(e)}")

        # Parse request body
        body = await request.json()
        new_req = ChatCompletionRequest.model_validate(body)
        if new_req.stream:
            raise HTTPException(status_code=400, detail="Streaming responses not supported in current AgentJet version, please set `stream=false` for now.")
        # Create timeline UUID
        timeline_uuid = uuid.uuid4().hex

        # Add to received queue
        # logger.warning(f"Received new chat completion request for agent: {agent_name}, target_tag: {target_tag}, episode_uuid: {episode_uuid}, timeline_uuid: {timeline_uuid}")
        int_req = InterchangeCompletionRequest(
            completion_request = new_req,
            agent_name = agent_name,
            target_tag = target_tag,
            episode_uuid = episode_uuid,
            timeline_uuid = timeline_uuid,
        )
        logger.info(f"episode_uuid: {episode_uuid} | Received new chat completion request for episode_uuid: {episode_uuid}, timeline_uuid: {timeline_uuid} (outside thread)")
        client_offline = threading.Event()
        try:
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, int_req, episode_uuid, timeline_uuid, client_offline)
            # BUGFIX: the worker returns HTTPException instances on failure;
            # raise them so the client receives the intended status code
            # instead of a 200 with a serialized exception object.
            if isinstance(result, HTTPException):
                raise result
            return result
        finally:
            # Ensure the worker thread's polling loop exits once the client
            # disconnects or the request is otherwise finished.
            client_offline.set()

    @app.post("/reset")
    async def reset():
        # Kept for API compatibility; currently a no-op acknowledgement.
        return {"status": "reset_complete"}

    return app

class InterchangeServer(Process):
238236
    def __init__(self, experiment_dir: str, port: int):
        """Configure the interchange server subprocess.

        Args:
            experiment_dir: experiment artifact directory, stored for use by
                run() (exact usage not visible in this view — confirm).
            port: TCP port the uvicorn server will bind to in run().
        """
        super().__init__()
        self.experiment_dir = experiment_dir
        self.port = port
242240

243241
def run(self):
242+
app = get_app()
244243
async def serve_with_monitor():
245-
# Start the monitor task
246-
asyncio.create_task(monitor_debug_state(self.experiment_dir))
247244
# Start the server
248245
config = uvicorn.Config(
249246
app=app,
@@ -254,8 +251,11 @@ async def serve_with_monitor():
254251
)
255252
server = uvicorn.Server(config)
256253
await server.serve()
257-
258-
asyncio.run(serve_with_monitor())
254+
try:
255+
asyncio.run(serve_with_monitor())
256+
except KeyboardInterrupt as e:
257+
SERVER_SHUTDOWN_EVENT.set()
258+
raise e
259259

260260

261261
# Convenience function for quick server startup
@@ -270,7 +270,6 @@ def start_interchange_server(experiment_dir) -> int:
270270
os.environ["AJET_DAT_INTERCHANGE_PORT"] = str(port)
271271

272272
interchange_server = InterchangeServer(experiment_dir, port)
273-
interchange_server.daemon = True
274273
interchange_server.start()
275274

276275
# Wait for server to be ready
@@ -288,6 +287,7 @@ def start_interchange_server(experiment_dir) -> int:
288287
time.sleep(0.5)
289288

290289
logger.info(f"Interchange server subprocess started on port {port} (pid: {interchange_server.pid})")
290+
atexit.register(lambda: interchange_server.terminate())
291291
return port
292292

293293

0 commit comments

Comments
 (0)