Skip to content

Commit 78ec41d

Browse files
committed
interchange threading -> interchange subprocess
1 parent f989472 commit 78ec41d

File tree

2 files changed

+79
-11
lines changed

2 files changed

+79
-11
lines changed

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def begin_service(self):
9898
"""
9999
Starts the websocket service loop.
100100
"""
101-
t = threading.Thread(target=lambda: asyncio.run(self._ensure_service_loop()))
101+
t = threading.Thread(target=lambda: asyncio.run(self._ensure_service_loop()), daemon=True)
102102
t.start()
103103

104104
async def _ensure_service_loop(self):
@@ -133,11 +133,10 @@ async def _service_loop(self):
133133

134134
try:
135135
# wait message from ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py
136-
parsed_msg: InterchangeCompletionRequest = pickle.loads(
136+
parsed_msg_str: str = pickle.loads(
137137
await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25)
138138
)
139-
if isinstance(parsed_msg, str):
140-
parsed_msg = InterchangeCompletionRequest(**json.loads(parsed_msg))
139+
parsed_msg:InterchangeCompletionRequest = InterchangeCompletionRequest(**json.loads(parsed_msg_str))
141140

142141
response = await self.llm_infer(
143142
req=parsed_msg.completion_request,

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from pydantic import BaseModel, ConfigDict, model_validator
2525
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Header, HTTPException, Request
2626
import uvicorn
27+
import sys
28+
import subprocess
29+
import atexit
30+
import argparse
2731

2832
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
2933
from openai.types.chat.chat_completion import ChatCompletion
@@ -83,11 +87,11 @@ async def coro_task_1_lookup_dict_received__send_loop(key, websocket: WebSocket,
8387
# will be received by:
8488
# ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
8589
# await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25)
86-
try:
87-
await websocket.send_bytes(pickle.dumps(new_req))
88-
except:
89-
# AgentScope sometimes fails the standard OAI schema compliance check for ChatCompletionRequest
90-
await websocket.send_bytes(pickle.dumps(new_req.model_dump_json()))
90+
# try:
91+
# await websocket.send_bytes(pickle.dumps(new_req))
92+
# except:
93+
# # AgentScope sometimes fails the standard OAI schema compliance check for ChatCompletionRequest
94+
await websocket.send_bytes(pickle.dumps(new_req.model_dump_json()))
9195
else:
9296
await asyncio.sleep(POLL_INTERVAL_SECONDS)
9397

@@ -407,11 +411,76 @@ def stop(self):
407411
def _resolve_interchange_port() -> int:
    """
    Decide which TCP port the interchange server subprocess should listen on.

    The port is read from the ``AJET_DAT_INTERCHANGE_PORT`` environment
    variable. When the variable is missing, is not an integer, or is outside
    the valid TCP range (1-65535), a free port is requested from the OS and
    written back to ``AJET_DAT_INTERCHANGE_PORT`` so that the subprocess
    (which inherits a copy of the environment) sees the same value.

    Returns:
        int: The port number to use.
    """
    import socket

    raw_port = os.environ.get("AJET_DAT_INTERCHANGE_PORT", "")
    try:
        port = int(raw_port)
    except (TypeError, ValueError):
        # A malformed value is treated exactly like "not specified".
        port = -1

    if 0 < port <= 65535:
        return port

    # Binding to port 0 lets the OS choose a currently unused port.
    # NOTE: the probe socket is closed before the subprocess binds, so another
    # process could grab the port in between; the window is small but real.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("", 0))
        port = probe.getsockname()[1]
    os.environ["AJET_DAT_INTERCHANGE_PORT"] = str(port)
    return port


def start_interchange_server(experiment_dir) -> int:
    """
    Start the interchange endpoint server and return the port number.
    This launches a subprocess that runs this file as a script.

    The subprocess is started with a copy of the current environment and is
    terminated by an ``atexit`` hook when the parent interpreter exits. This
    function returns right after spawning the subprocess; it does not wait
    for the server to accept connections.

    Args:
        experiment_dir: Directory forwarded to the subprocess through the
            ``--experiment_dir`` command line argument.

    Returns:
        int: The port number the server is running on.
    """
    port = _resolve_interchange_port()

    # We run this file as a script. Every argv entry must be a string, so
    # path-like ``experiment_dir`` values are converted explicitly.
    cmd = [
        sys.executable,
        os.path.abspath(__file__),
        "--experiment_dir",
        str(experiment_dir),
        "--port",
        str(port),
    ]

    # stdout/stderr are inherited on purpose: the server output is useful
    # for debugging.
    process = subprocess.Popen(cmd, env=os.environ.copy())

    def cleanup():
        # Nothing to do when the subprocess has already exited.
        if process.poll() is not None:
            return
        logger.info("Terminating interchange server subprocess")
        process.terminate()
        try:
            process.wait(timeout=2)
        except subprocess.TimeoutExpired:
            # Graceful shutdown took too long: force kill, then reap the
            # child so it does not linger as a zombie.
            process.kill()
            process.wait()

    atexit.register(cleanup)

    logger.info(f"Interchange server subprocess started on port {port} (pid: {process.pid})")
    return port
417455

456+
457+
if __name__ == "__main__":
    # Script entry point used by start_interchange_server(): parse the command
    # line, then run the uvicorn server and the debug-state monitor on one
    # asyncio event loop.
    parser = argparse.ArgumentParser(description="AJet Interchange Endpoint Server")
    parser.add_argument("--experiment_dir", type=str, required=True, help="Directory to store debug info")
    parser.add_argument("--port", type=int, required=True, help="Port to run the server on")

    args = parser.parse_args()

    async def serve_with_monitor():
        """Run the monitor task in the background and serve until shutdown."""
        # Keep a strong reference to the task: the event loop only holds a
        # weak reference, so an unreferenced task may be garbage-collected
        # before it finishes.
        monitor_task = asyncio.create_task(monitor_debug_state(args.experiment_dir))

        # Start the server
        config = uvicorn.Config(
            app=app,
            host="0.0.0.0",
            port=args.port,
            log_level="error",
            ws_max_queue=1024,
            ws_ping_interval=60,
            ws_ping_timeout=60,
            ws_per_message_deflate=True,
        )
        server = uvicorn.Server(config)
        try:
            await server.serve()
        finally:
            # The monitor has no purpose once the server has stopped.
            monitor_task.cancel()

    try:
        asyncio.run(serve_with_monitor())
    except KeyboardInterrupt:
        pass
486+

0 commit comments

Comments
 (0)