|
24 | 24 | from pydantic import BaseModel, ConfigDict, model_validator |
25 | 25 | from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Header, HTTPException, Request |
26 | 26 | import uvicorn |
| 27 | +import sys |
| 28 | +import subprocess |
| 29 | +import atexit |
| 30 | +import argparse |
27 | 31 |
|
28 | 32 | from vllm.entrypoints.openai.protocol import ChatCompletionRequest |
29 | 33 | from openai.types.chat.chat_completion import ChatCompletion |
@@ -83,11 +87,11 @@ async def coro_task_1_lookup_dict_received__send_loop(key, websocket: WebSocket, |
83 | 87 | # will be received by: |
84 | 88 | # ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py |
85 | 89 | # await asyncio.wait_for(websocket.recv(decode=False), timeout=0.25) |
86 | | - try: |
87 | | - await websocket.send_bytes(pickle.dumps(new_req)) |
88 | | - except: |
89 | | - # AgentScope sometimes fails the standard OAI schema compliance check for ChatCompletionRequest |
90 | | - await websocket.send_bytes(pickle.dumps(new_req.model_dump_json())) |
| 90 | + # try: |
| 91 | + # await websocket.send_bytes(pickle.dumps(new_req)) |
| 92 | + # except: |
| 93 | + # # AgentScope sometimes fails the standard OAI schema compliance check for ChatCompletionRequest |
| 94 | + await websocket.send_bytes(pickle.dumps(new_req.model_dump_json())) |
91 | 95 | else: |
92 | 96 | await asyncio.sleep(POLL_INTERVAL_SECONDS) |
93 | 97 |
|
@@ -407,11 +411,76 @@ def stop(self): |
def _wait_until_listening(process, port, timeout=30.0, poll_interval=0.2):
    """Wait for the server subprocess to accept TCP connections on ``port``.

    Args:
        process: The ``subprocess.Popen`` handle of the server subprocess.
        port: TCP port the subprocess is expected to listen on.
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds between connection attempts.

    Returns:
        bool: True once a connection to the port succeeds, False if
        ``timeout`` elapsed first.

    Raises:
        RuntimeError: If the subprocess exits before the port is reachable.
    """
    import socket
    import time

    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        exit_code = process.poll()
        if exit_code is not None:
            raise RuntimeError(
                f"Interchange server subprocess exited early with code {exit_code}"
            )
        try:
            # The connection is closed immediately; it only probes the listener.
            with socket.create_connection(("127.0.0.1", port), timeout=poll_interval):
                return True
        except OSError:
            time.sleep(poll_interval)
    return False


def start_interchange_server(experiment_dir) -> int:
    """
    Start the interchange endpoint server and return the port number.
    This launches a subprocess that runs this file as a script.

    The port is read from the ``AJET_DAT_INTERCHANGE_PORT`` environment
    variable. When that variable is missing, not an integer, or outside the
    valid TCP port range, a free port is picked and written back to the
    variable. The subprocess is terminated via an ``atexit`` hook when the
    current interpreter exits.

    Args:
        experiment_dir: Directory forwarded to the subprocess through
            ``--experiment_dir``.

    Returns:
        int: The port number the server is running on.

    Raises:
        RuntimeError: If the subprocess exits before it starts listening.
    """
    import socket

    # Treat a missing, empty, or non-numeric value as "no port requested"
    # instead of letting int() raise ValueError.
    try:
        port = int(os.environ.get("AJET_DAT_INTERCHANGE_PORT", ""))
    except ValueError:
        port = -1

    if not 0 < port <= 65535:
        # Ask the OS for a free port. NOTE: the port is released before the
        # subprocess binds it, so another process could grab it in between.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("", 0))
            port = probe.getsockname()[1]
        os.environ["AJET_DAT_INTERCHANGE_PORT"] = str(port)

    # Launch as subprocess; the child sees the (possibly updated) port variable.
    env = os.environ.copy()

    # We run this file as a script
    cmd = [
        sys.executable,
        os.path.abspath(__file__),
        "--experiment_dir",
        os.fspath(experiment_dir),
        "--port",
        str(port),
    ]

    # stdout/stderr are inherited on purpose: server output is useful for debugging.
    process = subprocess.Popen(cmd, env=env)

    def cleanup():
        # Nothing to do when the subprocess has already exited.
        if process.poll() is not None:
            return
        logger.info("Terminating interchange server subprocess")
        process.terminate()
        try:
            process.wait(timeout=2)
        except subprocess.TimeoutExpired:
            process.kill()
            # Reap the killed child so it does not linger as a zombie.
            process.wait()

    atexit.register(cleanup)

    logger.info(f"Interchange server subprocess started on port {port} (pid: {process.pid})")

    # Do not hand the port to callers before the server can accept connections;
    # a subprocess that died on startup is reported instead of silently ignored.
    if not _wait_until_listening(process, port):
        logger.warning(
            f"Interchange server on port {port} is not accepting connections yet; continuing anyway"
        )
    return port
417 | 455 |
|
| 456 | + |
if __name__ == "__main__":
    # Script entry point: used when this file is launched as a standalone
    # server process with --experiment_dir and --port.
    parser = argparse.ArgumentParser(description="AJet Interchange Endpoint Server")
    parser.add_argument("--experiment_dir", type=str, required=True, help="Directory to store debug info")
    parser.add_argument("--port", type=int, required=True, help="Port to run the server on")

    args = parser.parse_args()

    async def serve_with_monitor():
        """Run the debug-state monitor alongside the uvicorn server."""
        # Keep a reference to the task: the event loop only holds weak
        # references, so an unreferenced task may be garbage-collected
        # before it finishes.
        monitor_task = asyncio.create_task(monitor_debug_state(args.experiment_dir))

        # Start the server
        config = uvicorn.Config(
            app=app,
            host="0.0.0.0",
            port=args.port,
            log_level="error",
            ws_max_queue=1024,
            ws_ping_interval=60,
            ws_ping_timeout=60,
            ws_per_message_deflate=True,
        )
        server = uvicorn.Server(config)
        try:
            await server.serve()
        finally:
            # The monitor has no purpose once the server has stopped.
            monitor_task.cancel()

    try:
        asyncio.run(serve_with_monitor())
    except KeyboardInterrupt:
        pass
| 486 | + |
0 commit comments