Skip to content

Commit 0782173

Browse files
committed
keyurl multinode training support
1 parent d0b6cae commit 0782173

17 files changed

Lines changed: 188 additions & 264 deletions

ajet/backbone/main_trinity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ def get_ajet_config_from_trinity_side():
2323

2424
def patch_runtime_env_to_get_actor():
2525
"""Patch the classmethod of Explorer and Trainer to pass in the runtime env."""
26-
runtime_env = get_runtime_env(is_trinity=True)
26+
ajet_config = get_ajet_config_from_trinity_side()
27+
runtime_env = get_runtime_env(ajet_config, is_trinity=True)
2728

2829
def patched_explorer_get_actor(cls, config: Config):
2930
return (
@@ -50,7 +51,6 @@ def patched_trainer_get_actor(cls, config: Config):
5051
Explorer.get_actor = classmethod(patched_explorer_get_actor)
5152
Trainer.get_actor = classmethod(patched_trainer_get_actor)
5253

53-
ajet_config = get_ajet_config_from_trinity_side()
5454
if ajet_config.ajet.enable_experimental_interchange_server:
5555
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
5656
start_interchange_server(ajet_config)

ajet/backbone/main_verl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def run_ppo(config) -> None:
5555
# Check if Ray is not initialized
5656
if not ray.is_initialized():
5757
# this is for local ray cluster
58-
runtime_env = get_runtime_env()
58+
runtime_env = get_runtime_env(config)
5959
print_dict(runtime_env["env_vars"], "runtime_env")
6060
ray.init(
6161
runtime_env=runtime_env,

ajet/backbone/main_vllm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def run(config):
180180
def main(config):
181181
from omegaconf import OmegaConf
182182
OmegaConf.resolve(config)
183-
runtime_env = get_runtime_env()
183+
runtime_env = get_runtime_env(config)
184184
os.environ.update(runtime_env["env_vars"])
185185
# atexit.register(lambda: print("Process exiting, performing cleanup..."))
186186

ajet/backbone/warm_up.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,42 @@ def init_parallel_rollout_logger(experiment_name):
4343

4444

4545

46+
def warm_up_task_judge_when_needed(config):
47+
if config.ajet.task_judge.judge_type == "rubrics_auto_grader":
48+
from ajet.task_judge.rm_auto_grader_judge import AutoGraderJudge
49+
50+
judge = AutoGraderJudge(config)
51+
asyncio.run(judge.generate_rubrics_from_samples())
52+
asyncio.run(judge.load_rubrics_from_cache())
53+
54+
55+
def clean_up_tmp_ajet_dir(config):
56+
"""Clean up old IPC socket files in /tmp/ajet directory."""
57+
import time
58+
if config.ajet.enable_experimental_interchange_server is False:
59+
return
60+
61+
tmp_dir = "/tmp/ajet"
62+
if not os.path.exists(tmp_dir):
63+
return
64+
current_time = time.time()
65+
ttl = 4 * 3600
66+
try:
67+
for filename in os.listdir(tmp_dir):
68+
if not filename.endswith(".sock"):
69+
continue
70+
71+
file_path = os.path.join(tmp_dir, filename)
72+
try:
73+
print(current_time - os.path.getmtime(file_path))
74+
if current_time - os.path.getmtime(file_path) > ttl:
75+
os.remove(file_path)
76+
except OSError:
77+
pass
78+
except OSError:
79+
pass
80+
81+
4682
def warm_up_process(config):
4783
"""
4884
Process level warm up
@@ -65,12 +101,4 @@ def warm_up_process(config):
65101
experiment_name = config.ajet.experiment_name
66102
init_parallel_rollout_logger(experiment_name)
67103
warm_up_task_judge_when_needed(config)
68-
69-
70-
def warm_up_task_judge_when_needed(config):
71-
if config.ajet.task_judge.judge_type == "rubrics_auto_grader":
72-
from ajet.task_judge.rm_auto_grader_judge import AutoGraderJudge
73-
74-
judge = AutoGraderJudge(config)
75-
asyncio.run(judge.generate_rubrics_from_samples())
76-
asyncio.run(judge.load_rubrics_from_cache())
104+
clean_up_tmp_ajet_dir(config)

ajet/default_config/ajet_default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ ajet:
280280
save_trajectory_as_json_file: False
281281

282282

283-
# the experimental reverse proxy feature that allows `tuner.as_oai_baseurl_apikey` feature
283+
# the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature
284284
enable_experimental_interchange_server: True
285285
interchange_server:
286286
interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node)

ajet/tuner.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ def __init__(
2626
self.context_tracker = context_tracker
2727
self.llm_inference_fn = llm_inference_fn
2828
self.target2proxy_registry: dict[str, dict[str,TunerTypeUnion]] = {}
29-
if config.ajet.enable_experimental_interchange_server:
29+
self.enable_interchange_server = config.ajet.enable_experimental_interchange_server
30+
if self.enable_interchange_server:
3031
self.proxy_client_started = False
3132

3233

@@ -104,7 +105,7 @@ def as_oai_baseurl_apikey(
104105
```
105106
"""
106107

107-
assert self.config.ajet.enable_experimental_interchange_server, "Please enable `ajet.enable_experimental_interchange_server` in yaml config to use `as_oai_baseurl_apikey` feature."
108+
assert self.enable_interchange_server, "Please enable `ajet.enable_experimental_interchange_server` in yaml config to use `as_oai_baseurl_apikey` feature."
108109
if self.proxy_client_started is False:
109110
self.proxy_client_started = True
110111
self._enable_experimental_interchange_server(self.llm_inference_fn)
@@ -171,7 +172,7 @@ def get_context_tracker(self) -> MultiAgentContextTracker:
171172

172173
def _enable_experimental_interchange_server(self, llm_inference_fn):
173174
# experimental reverse proxy start
174-
if self.config.ajet.enable_experimental_interchange_server:
175+
if self.enable_interchange_server:
175176
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_client import InterchangeClient
176177
self.interchange_client = InterchangeClient(
177178
episode_uuid=self.context_tracker.episode_uuid,
@@ -184,6 +185,6 @@ def _enable_experimental_interchange_server(self, llm_inference_fn):
184185

185186
def terminate_episode(self):
186187
# experimental reverse proxy cleanup
187-
if self.config.ajet.enable_experimental_interchange_server:
188+
if self.enable_interchange_server:
188189
if (self.proxy_client_started is True) and hasattr(self, "interchange_client"):
189190
self.interchange_client._should_terminate = True

ajet/tuner_deprecated.py

Lines changed: 0 additions & 194 deletions
This file was deleted.

ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from openai.resources.chat.chat import Chat, AsyncChat
1313
from openai.resources.completions import AsyncCompletions
1414
from openai import OpenAI, AsyncOpenAI
15-
from ajet.utils.free_port import find_free_port
15+
from ajet.utils.networking import find_free_port
1616
from .experimental.as_oai_model_client import generate_auth_token
1717

1818
if TYPE_CHECKING:
@@ -47,9 +47,12 @@ def __init__(
4747
episode_contect_address: str,
4848
**kwargs,
4949
):
50+
5051
port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
5152
assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set"
52-
base_url = f"http://localhost:{port}/v1"
53+
master_node_ip = os.getenv("MASTER_NODE_IP", "localhost")
54+
55+
base_url = f"http://{master_node_ip}:{port}/v1"
5356
api_key = generate_auth_token(
5457
agent_name=agent_name,
5558
target_tag=target_tag,

0 commit comments

Comments
 (0)