Skip to content

Commit 4a151c4

Browse files
committed
Enhance interchange server functionality with experiment directory support
1 parent 6866109 commit 4a151c4

File tree

4 files changed

+34
-19
lines changed

4 files changed

+34
-19
lines changed

ajet/backbone/main_trinity.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
import ray
2+
import os
23
from trinity.cli.launcher import main
34
from trinity.common.config import Config
45
from trinity.explorer.explorer import Explorer
56
from trinity.trainer.trainer import Trainer
67

8+
from ajet.utils.config_utils import read_ajet_config_with_cache
79
from ajet.utils.core_env_vars import get_runtime_env
810
from ajet.utils.launch_utils import set_loguru_default_color
911

12+
1013
set_loguru_default_color()
1114

1215

16+
def get_ajet_config_from_trinity_side():
17+
yaml_path = os.environ.get("AJET_CONFIG_REDIRECT", None)
18+
if yaml_path is None:
19+
raise ValueError("AJET_CONFIG_REDIRECT is not set in environment variables")
20+
ajet_config = read_ajet_config_with_cache(yaml_path)
21+
return ajet_config
22+
23+
1324
def patch_runtime_env_to_get_actor():
1425
"""Patch the classmethod of Explorer and Trainer to pass in the runtime env."""
1526
runtime_env = get_runtime_env(is_trinity=True)
@@ -39,7 +50,10 @@ def patched_trainer_get_actor(cls, config: Config):
3950
Explorer.get_actor = classmethod(patched_explorer_get_actor)
4051
Trainer.get_actor = classmethod(patched_trainer_get_actor)
4152

42-
53+
ajet_config = get_ajet_config_from_trinity_side()
54+
if ajet_config.ajet.enable_experimental_reverse_proxy:
55+
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
56+
start_interchange_server(ajet_config.ajet.experiment_dir)
4357

4458

4559
if __name__ == "__main__":

ajet/backbone/main_verl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def run(self, config):
249249

250250
if config.ajet.enable_experimental_reverse_proxy:
251251
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
252-
start_interchange_server()
252+
start_interchange_server(config.ajet.experiment_dir)
253253

254254
# Initialize the PPO trainer.
255255
trainer = AjetRayPPOTrainer(

ajet/backbone/main_vllm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,10 @@ def main(config):
137137

138138
runtime_env = get_runtime_env()
139139
os.environ.update(runtime_env["env_vars"])
140+
140141
if config.ajet.enable_experimental_reverse_proxy:
141142
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
142-
start_interchange_server()
143+
start_interchange_server(config.ajet.experiment_dir)
143144

144145
def companion_launch():
145146
import torch

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def coro_task_1_lookup_dict_received__send_loop(key, websocket: WebSocket,
5656
while not stop_event.is_set():
5757
# Check for new requests in ajet_remote_handler_received
5858
if (key in ajet_remote_handler_received) and len(ajet_remote_handler_received[key]) > 0:
59-
logger.warning(f"Sending new request to client for key: {key}")
59+
# logger.warning(f"Sending new request to client for key: {key}")
6060

6161
timeline_uuid = list(ajet_remote_handler_received[key].keys())[0]
6262

@@ -95,9 +95,9 @@ async def coro_task_2_lookup_dict_received__receive_loop(key, websocket: WebSock
9595
# Wait for client response:
9696
# ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
9797
# await websocket.send(pickle.dumps(response))
98-
logger.warning(f"Waiting for response from client for key: {key}")
98+
# logger.warning(f"Waiting for response from client for key: {key}")
9999
response_data = pickle.loads(await websocket.receive_bytes())
100-
logger.warning(f"Received response from client for key: {key}")
100+
# logger.warning(f"Received response from client for key: {key}")
101101

102102
if not isinstance(response_data, ChatCompletion):
103103
stop_event.set()
@@ -152,7 +152,7 @@ async def context_tracker_client_listen(websocket: WebSocket):
152152
assert episode_uuid_str.startswith("episode_uuid:")
153153
episode_uuid = episode_uuid_str.split("episode_uuid:")[-1]
154154

155-
logger.warning(f"WebSocket client connected for episode_uuid: {episode_uuid}")
155+
# logger.warning(f"WebSocket client connected for episode_uuid: {episode_uuid}")
156156

157157
key = f"episode_uuid:{episode_uuid}"
158158
active_websockets[key] = websocket
@@ -166,7 +166,7 @@ async def context_tracker_client_listen(websocket: WebSocket):
166166
logger.exception(f"Error in websocket connection setup: {e}")
167167

168168
finally:
169-
logger.warning(f"WebSocket client disconnected for key: {key}")
169+
# logger.warning(f"WebSocket client disconnected for key: {key}")
170170
if key:
171171
# Clean up any in-progress requests for this key
172172
for container in [
@@ -220,7 +220,7 @@ async def chat_completions(request: Request, authorization: str = Header(None)):
220220
# Create timeline UUID
221221
timeline_uuid = uuid.uuid4().hex
222222
# Add to received queue
223-
logger.warning(f"Received new chat completion request for agent: {agent_name}, target_tag: {target_tag}, episode_uuid: {episode_uuid}, timeline_uuid: {timeline_uuid}")
223+
# logger.warning(f"Received new chat completion request for agent: {agent_name}, target_tag: {target_tag}, episode_uuid: {episode_uuid}, timeline_uuid: {timeline_uuid}")
224224
ajet_remote_handler_received[key][timeline_uuid] = InterchangeCompletionRequest(
225225
completion_request = new_req,
226226
agent_name = agent_name,
@@ -259,7 +259,7 @@ async def reset():
259259
"""
260260
Reset endpoint to clear all state and disconnect all websockets.
261261
"""
262-
logger.warning("Resetting interchange endpoint server state.")
262+
# logger.warning("Resetting interchange endpoint server state.")
263263
# Disconnect all websockets
264264
for key, ws in list(active_websockets.items()):
265265
try:
@@ -278,7 +278,7 @@ async def reset():
278278
return {"status": "reset_complete"}
279279

280280

281-
async def monitor_debug_state():
281+
async def monitor_debug_state(experiment_dir):
282282
"""
283283
Background task to write debug state to ./interchange_debug.txt every 1 second.
284284
"""
@@ -292,14 +292,14 @@ async def monitor_debug_state():
292292
'active_websockets': list(active_websockets.keys())
293293
}
294294

295-
with open('./interchange_debug.txt', 'w') as f:
295+
with open(f'{experiment_dir}/interchange_debug.txt', 'w') as f:
296296
f.write(pformat(debug_info, width=120, indent=2))
297297
f.write('\n')
298298

299-
await asyncio.sleep(1)
299+
await asyncio.sleep(2)
300300
except Exception as e:
301301
logger.error(f"Error in monitor_debug_state: {e}")
302-
await asyncio.sleep(1)
302+
await asyncio.sleep(2)
303303

304304

305305
def ensure_dat_interchange_server_cache_clear():
@@ -332,7 +332,7 @@ def __init__(self):
332332
self.server_thread = None
333333
self.server = None
334334

335-
def start(self) -> int:
335+
def start(self, experiment_dir) -> int:
336336
"""
337337
Start the FastAPI server on a free port.
338338
@@ -346,14 +346,14 @@ def start(self) -> int:
346346
def run_server():
347347
async def serve_with_monitor():
348348
# Start the monitor task
349-
monitor_task = asyncio.create_task(monitor_debug_state())
349+
monitor_task = asyncio.create_task(monitor_debug_state(experiment_dir))
350350

351351
# Start the server
352352
config = uvicorn.Config(
353353
app=app,
354354
host="0.0.0.0",
355355
port=self.port,
356-
log_level="info"
356+
log_level="error"
357357
)
358358
server = uvicorn.Server(config)
359359
await server.serve()
@@ -374,14 +374,14 @@ def stop(self):
374374

375375

376376
# Convenience function for quick server startup
377-
def start_interchange_server() -> int:
377+
def start_interchange_server(experiment_dir) -> int:
378378
"""
379379
Start the interchange endpoint server and return the port number.
380380
381381
Returns:
382382
int: The port number the server is running on.
383383
"""
384384
server = InterchangeEndpointServer()
385-
port = server.start()
385+
port = server.start(experiment_dir)
386386
return port
387387

0 commit comments

Comments
 (0)