Skip to content

Commit 75a4321

Browse files
committed
update yaml configuration
1 parent 8eae43c commit 75a4321

19 files changed

+135
-135
lines changed

ajet/backbone/main_trinity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ def patched_trainer_get_actor(cls, config: Config):
5151
Trainer.get_actor = classmethod(patched_trainer_get_actor)
5252

5353
ajet_config = get_ajet_config_from_trinity_side()
54-
if ajet_config.ajet.enable_experimental_reverse_proxy:
54+
if ajet_config.ajet.enable_experimental_interchange_server:
5555
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
56-
start_interchange_server(ajet_config.ajet.experiment_dir)
56+
start_interchange_server(ajet_config)
5757

5858

5959
if __name__ == "__main__":

ajet/backbone/main_verl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,9 @@ def run(self, config):
246246

247247
from ajet.backbone.trainer_verl import AjetRayPPOTrainer
248248

249-
if config.ajet.enable_experimental_reverse_proxy:
249+
if config.ajet.enable_experimental_interchange_server:
250250
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
251-
start_interchange_server(config.ajet.experiment_dir)
251+
start_interchange_server(config)
252252

253253
# Initialize the PPO trainer.
254254
trainer = AjetRayPPOTrainer(

ajet/backbone/main_vllm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ def main(config):
184184
os.environ.update(runtime_env["env_vars"])
185185
# atexit.register(lambda: print("Process exiting, performing cleanup..."))
186186

187-
if config.ajet.enable_experimental_reverse_proxy:
187+
if config.ajet.enable_experimental_interchange_server:
188188
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
189-
start_interchange_server(config.ajet.experiment_dir)
189+
start_interchange_server(config)
190190

191191
def companion_launch():
192192
import torch

ajet/backbone/trainer_trinity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def __init__(
116116

117117
async def run_async(self):
118118
ajet_config = get_ajet_config_from_trinity_side()
119-
if ajet_config.ajet.enable_experimental_reverse_proxy:
119+
if ajet_config.ajet.enable_experimental_interchange_server:
120120
raise NotImplementedError(
121121
"The experimental reverse proxy is not supported in Trinity backbone yet."
122122
)

ajet/backbone/trainer_verl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ def fit(self): # noqa: C901
835835
self.global_steps += 1
836836

837837
# # when enabled oai request interchange, we need to clear the cache from time to time
838-
# if self.config.ajet.enable_experimental_reverse_proxy:
838+
# if self.config.ajet.enable_experimental_interchange_server:
839839
# from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear
840840
# ensure_dat_interchange_server_cache_clear()
841841

ajet/backbone/warm_up.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""
2+
Process level warm up
3+
"""
4+
5+
16
import asyncio
27
import logging
38
import os
@@ -32,6 +37,9 @@ def init_parallel_rollout_logger(experiment_name):
3237

3338
target_logger = logging.getLogger("vllm.entrypoints.openai.tool_parsers.hermes_tool_parser")
3439
target_logger.setLevel(logging.CRITICAL)
40+
logging.getLogger("httpx").setLevel(logging.WARNING)
41+
42+
3543

3644
def warm_up_process(config):
3745
"""

ajet/default_config/ajet_default.yaml

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,6 @@ ajet:
66
backbone: debug # `debug` or `trinity` or `verl`
77

88

9-
# the experimental reverse proxy feature that allows `tuner.as_oai_baseurl_apikey` feature
10-
enable_experimental_reverse_proxy: False
11-
12-
# submit llm infer submit method
13-
llm_infer_submit_method: "async" # options: "sync", "async"
14-
15-
task_runner:
16-
wrapper_type: "asyncio-with-gc"
17-
wrapper_multiprocessing_timeout: 3600 # in seconds
18-
# - wrapper_type: "asyncio-with-gc": safe, with periodic garbage collection to prevent event loop leaks (recommended)
19-
# - wrapper_type: "asyncio": fast, but may cause event loop leak in long run
20-
# - wrapper_type: "multi-processing": safe, but resource consuming
21-
229
model:
2310
# which model should be trained
2411
path: /path/to/model/such/as/Qwen/Qwen2___5-14B-Instruct
@@ -42,7 +29,7 @@ ajet:
4229
force_disable_toolcalls: False
4330

4431
# maximum number of parallel environments / simulate workers
45-
max_env_worker: 128
32+
max_env_worker: 64
4633

4734
# step reward gamma (experimental, do not change)
4835
gamma: 1.0
@@ -293,7 +280,28 @@ ajet:
293280
save_trajectory_as_json_file: False
294281

295282

283+
# the experimental reverse proxy feature that allows `tuner.as_oai_baseurl_apikey` feature
284+
enable_experimental_interchange_server: False
285+
interchange_server:
286+
interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node)
287+
interchange_server_port: 'auto'
288+
num_fastapi_process: 2 # 1, 2 or 4 is fine
289+
max_fastapi_threads: 128 # 64 or 128 is fine
290+
max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker`
291+
292+
293+
task_runner:
294+
# submit llm infer submit method
295+
llm_infer_submit_method: "async" # options: "sync", "async"
296+
297+
# how to wrap the user-defined workflow
298+
wrapper_type: "asyncio-with-gc"
299+
# - wrapper_type: "asyncio-with-gc": safe, with periodic garbage collection to prevent event loop leaks (recommended)
300+
# - wrapper_type: "asyncio": fast, but may cause event loop leak in long run
301+
# - wrapper_type: "multi-processing": safe, but resource consuming
296302

303+
# when `wrapper_type` is `multi-processing`, the timeout for each task
304+
wrapper_multiprocessing_timeout: 3600 # in seconds
297305

298306
# DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN.
299307
execute_test: False # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN.

ajet/task_rollout/async_llm_bridge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ async def run_infer(
534534
# otherwise, for abnormal output, can still proceed, but we do not track output anymore
535535

536536
# run llm inference ✨
537-
if self.config.ajet.llm_infer_submit_method == "sync":
537+
if self.config.ajet.task_runner.llm_infer_submit_method == "sync":
538538
llm_output = await asyncio.to_thread(
539539
self.llm_inference_fn, converted_message, custom_sampling_params, tools
540540
)

ajet/task_rollout/single_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def rollout_env_worker(
8585
"""
8686
sampling_params = get_sample_params(mode, self.config)
8787

88-
if self.config.ajet.llm_infer_submit_method == "sync":
88+
if self.config.ajet.task_runner.llm_infer_submit_method == "sync":
8989
llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_sync(
9090
sampling_params=sampling_params
9191
)

ajet/tuner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626
self.context_tracker = context_tracker
2727
self.llm_inference_fn = llm_inference_fn
2828
self.target2proxy_registry: dict[str, dict[str,TunerTypeUnion]] = {}
29-
if config.ajet.enable_experimental_reverse_proxy:
29+
if config.ajet.enable_experimental_interchange_server:
3030
self.proxy_client_started = False
3131

3232

@@ -104,10 +104,10 @@ def as_oai_baseurl_apikey(
104104
```
105105
"""
106106

107-
assert self.config.ajet.enable_experimental_reverse_proxy, "Please enable `ajet.enable_experimental_reverse_proxy` in yaml config to use `as_oai_baseurl_apikey` feature."
107+
assert self.config.ajet.enable_experimental_interchange_server, "Please enable `ajet.enable_experimental_interchange_server` in yaml config to use `as_oai_baseurl_apikey` feature."
108108
if self.proxy_client_started is False:
109-
self._enable_experimental_interchange_server(self.llm_inference_fn)
110109
self.proxy_client_started = True
110+
self._enable_experimental_interchange_server(self.llm_inference_fn)
111111
baseurl_apikey_model = OpenaiClientBaseUrlTuner(
112112
config=self.config,
113113
context_tracker=self.context_tracker,
@@ -171,7 +171,7 @@ def get_context_tracker(self) -> MultiAgentContextTracker:
171171

172172
def _enable_experimental_interchange_server(self, llm_inference_fn):
173173
# experimental reverse proxy start
174-
if self.config.ajet.enable_experimental_reverse_proxy:
174+
if self.config.ajet.enable_experimental_interchange_server:
175175
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_client import InterchangeClient
176176
self.interchange_client = InterchangeClient(
177177
episode_uuid=self.context_tracker.episode_uuid,
@@ -184,6 +184,6 @@ def _enable_experimental_interchange_server(self, llm_inference_fn):
184184

185185
def terminate_episode(self):
186186
# experimental reverse proxy cleanup
187-
if self.config.ajet.enable_experimental_reverse_proxy:
187+
if self.config.ajet.enable_experimental_interchange_server:
188188
if (self.proxy_client_started is True) and hasattr(self, "interchange_client"):
189189
self.interchange_client._should_terminate = True

0 commit comments

Comments
 (0)