Skip to content

Commit d10b5bc

Browse files
authored
refactor agentjet job system (#11)
* refactor * stage werewolves example * refactor: update configuration and improve swarm client functionality * improve communication protocol * fix bug * fix save-dir bug * force check agentscope version * share httpx client * fix memory leak * add thread safety to cache operations and implement LRU eviction * implement skills and skillbench example
1 parent a1181cc commit d10b5bc

File tree

61 files changed

+1186
-624
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

61 files changed

+1186
-624
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,5 @@ modelscope_cache
170170
prompts
171171
swarmexp
172172
swarmlog
173+
werewolves_swarm
174+
.claude

ajet/backbone/main_trinity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def patched_trainer_get_actor(cls, config: Config):
5252
Explorer.get_actor = classmethod(patched_explorer_get_actor)
5353
Trainer.get_actor = classmethod(patched_trainer_get_actor)
5454

55-
if ajet_config.ajet.enable_experimental_interchange_server:
55+
if ajet_config.ajet.enable_interchange_server:
5656
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
5757
start_interchange_server(ajet_config)
5858

ajet/backbone/main_verl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def run_ppo(config: DictConfig) -> None:
6767
def on_shutdown():
6868
if ray.is_initialized():
6969
ray.shutdown()
70-
if config.ajet.enable_experimental_interchange_server:
70+
if config.ajet.enable_interchange_server:
7171
if config.ajet.enable_swarm_mode:
7272
from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status
7373
print("Changing engine status to OFFLINE before shutdown...")
@@ -250,7 +250,7 @@ def run(self, config):
250250

251251
from ajet.backbone.trainer_verl import AjetRayPPOTrainer
252252

253-
if config.ajet.enable_experimental_interchange_server:
253+
if config.ajet.enable_interchange_server:
254254
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
255255
start_interchange_server(config)
256256

ajet/backbone/main_vllm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def main(config):
186186
os.environ.update(runtime_env["env_vars"])
187187
# atexit.register(lambda: print("Process exiting, performing cleanup..."))
188188

189-
if config.ajet.enable_experimental_interchange_server:
189+
if config.ajet.enable_interchange_server:
190190
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
191191
start_interchange_server(config)
192192
if config.ajet.enable_swarm_mode:

ajet/backbone/trainer_verl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def init_workers(self):
457457
)
458458

459459
def _update_interchange_server_status_flag(self, status: str):
460-
if self.config.ajet.enable_experimental_interchange_server:
460+
if self.config.ajet.enable_interchange_server:
461461
if self.config.ajet.enable_swarm_mode:
462462
from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status
463463
http_change_engine_status(self.config, status, global_step=self.global_steps)
@@ -858,7 +858,7 @@ def fit(self): # noqa: C901
858858
self.global_steps += 1
859859

860860
# # when enabled oai request interchange, we need to clear the cache from time to time
861-
# if self.config.ajet.enable_experimental_interchange_server:
861+
# if self.config.ajet.enable_interchange_server:
862862
# from ajet.tuner_lib.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear
863863
# ensure_dat_interchange_server_cache_clear()
864864

ajet/backbone/warm_up.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def warm_up_task_judge_when_needed(config):
5454
def clean_up_tmp_ajet_dir(config):
5555
"""Clean up old IPC socket files in /tmp/ajet directory."""
5656
import time
57-
if config.ajet.enable_experimental_interchange_server is False:
57+
if config.ajet.enable_interchange_server is False:
5858
return
5959

6060
tmp_dir = "/tmp/ajet"

ajet/context_tracker/multiagent_tracking.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -84,19 +84,6 @@ def extract_text_content_from_content_dict(self, msg):
8484
# },
8585
# ],
8686
# }
87-
# or tool_result format?? not observed yet:
88-
# msg = {
89-
# "role": "tool",
90-
# "content": [
91-
# {
92-
# "type": "tool_result",
93-
# "id": "call_xxx",
94-
# "output": "tool output content",
95-
# "name": "tool_name"
96-
# },
97-
# ],
98-
# }
99-
10087

10188
str_content = ""
10289
for item in msg["content"]:
@@ -332,6 +319,7 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline):
332319
)
333320
):
334321
logger.bind(exception=True).info(f"General Warning: merge failure discovered.\n")
322+
# from ajet import bp; bp("SWARM")
335323
return
336324

337325

@@ -346,7 +334,9 @@ def detect_tool_call_madness(self, llm_output):
346334
# llm_output["tool_calls"] is not None, and is not []
347335
tool_calls = llm_output["tool_calls"]
348336
if "wrong_toolcall" in self.config.ajet.rollout.compute_madness_checklist:
349-
copy_tool_calls = copy.deepcopy(tool_calls)
337+
# copy_tool_calls = copy.deepcopy(tool_calls)
338+
# Shallow copy is sufficient - we're only reading the data
339+
copy_tool_calls = tool_calls
350340
wrong_toolcall = False
351341
for i in range(len(copy_tool_calls)):
352342
if ("function" in copy_tool_calls[i]) and (

ajet/copilot/job.py

Lines changed: 99 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
from __future__ import annotations
99

1010
import os
11+
import time
12+
import yaml
1113
import tempfile
12-
from types import SimpleNamespace
13-
from typing import Any, Callable, Union
1414

15-
import yaml
15+
from types import SimpleNamespace
16+
from typing import Any, Callable, Union, cast
1617
from loguru import logger
17-
18-
1918
from ajet.default_config.ajet_default import Config
2019
from ajet.utils.config_utils import (
2120
expand_ajet_hierarchical_config,
@@ -30,70 +29,118 @@
3029
setup_environment_vars,
3130
)
3231

33-
DEFAULT_DIR = "saved_experiments"
32+
33+
def override_current_yaml_value_if_given(override_value, current_value):
34+
if override_value is not None:
35+
return override_value
36+
else:
37+
return current_value
38+
39+
def _set_nested_attr(obj, attr_path: str, value):
40+
keys = attr_path.split(".")
41+
for key in keys[:-1]:
42+
obj = getattr(obj, key)
43+
setattr(obj, keys[-1], value)
44+
45+
def _get_nested_attr(obj, attr_path: str):
46+
for key in attr_path.split("."):
47+
obj = getattr(obj, key)
48+
return obj
3449

3550
class AgentJetJob:
36-
"""Lightweight builder that launches AgentJet training as a subprocess."""
51+
"""
52+
arg: base_yaml_config + **kwargs (yaml config, then override with kwargs)
53+
arg: base_yaml_config (yaml config)
54+
arg: **kwargs (yaml config, then override with kwargs)
55+
"""
3756

3857
def __init__(
3958
self,
40-
backbone: str = "verl",
41-
model: str = "Qwen/Qwen2___5-7B-Instruct",
42-
n_gpu: int = 8,
43-
algorithm: str = "grpo",
44-
project_name="ajet-swarm",
45-
experiment_name="test",
46-
n_gpu_for_infer: int | None = None, # only for trinity backbone
47-
num_repeat: int = 8,
48-
batch_size: int = 32,
49-
swarm_mode: bool = True,
50-
sample_collection_method: str = "rollout_until_finish_enough_tasks",
51-
*kwargs,
59+
base_yaml_config: str | None = None,
60+
experiment_dir: str | None = None,
61+
project_name: str | None = None,
62+
experiment_name: str | None = None,
63+
n_gpu: int | None = None,
64+
model: str | None = None,
65+
algorithm: str | None = None,
66+
num_repeat: int | None = None,
67+
batch_size: int | None = None,
68+
swarm_mode: bool | None = None,
69+
swarm_mode_sample_collection_method: str | None = None,
70+
max_env_worker: int | None = None,
71+
backbone: str | None = None,
5272
) -> None:
53-
self.backbone = backbone
54-
self.exp_dir = DEFAULT_DIR
55-
self.project_name = project_name
56-
self.exp_name = experiment_name
57-
self.sample_collection_method = sample_collection_method
58-
if swarm_mode:
59-
default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml"))
73+
74+
if base_yaml_config is None:
75+
base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml"))
6076
else:
61-
default_yaml = None
62-
self.config_as_dict: dict = self.build_job_from_yaml(default_yaml)
77+
logger.warning(f"Reading config from {base_yaml_config}.")
78+
time.sleep(1)
79+
self.config_as_dict: dict = self.build_job_from_yaml(base_yaml_config)
6380
self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict)
6481

65-
self.config.ajet.experiment_name = experiment_name
66-
self.config.ajet.backbone = backbone
67-
self.config.ajet.model.path = model
68-
self.config.ajet.trainer_common.n_gpus_per_node = n_gpu
69-
self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm
70-
self.config.ajet.rollout.num_repeat = num_repeat
71-
self.config.ajet.data.train_batch_size = batch_size
72-
self.config.ajet.enable_swarm_mode = swarm_mode
73-
self.config.ajet.swarm_mode_sample_collection_method = sample_collection_method
74-
if n_gpu_for_infer is None and backbone == "trinity":
75-
raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.")
76-
if (n_gpu_for_infer is not None) and backbone == "verl":
77-
raise ValueError("n_gpu_for_infer is only for trinity backbone, please set it to `None`.")
78-
else:
79-
if backbone == "trinity":
80-
assert isinstance(n_gpu_for_infer, int), f"`n_gpu_for_infer` should be int, got {type(n_gpu_for_infer)}."
81-
assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`."
82-
self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer
83-
self.config.ajet.rollout.tensor_model_parallel_size = 1
82+
self.base_yaml_config: str = cast(str, base_yaml_config) # currently may be None, but will be set later
83+
self.experiment_dir: str = cast(str, experiment_dir)
84+
self.project_name: str = cast(str, project_name)
85+
self.experiment_name: str = cast(str, experiment_name)
86+
self.n_gpu: int = cast(int, n_gpu)
87+
self.model: str = cast(str, model)
88+
self.algorithm: str = cast(str, algorithm)
89+
self.num_repeat: int = cast(int, num_repeat)
90+
self.batch_size: int = cast(int, batch_size)
91+
self.swarm_mode: bool = cast(bool, swarm_mode)
92+
self.swarm_mode_sample_collection_method: str = cast(str, swarm_mode_sample_collection_method)
93+
self.max_env_worker: int = cast(int, max_env_worker)
94+
self.backbone: str = cast(str, backbone)
95+
96+
# see `ajet/default_config/ajet_ts_default.yaml`
97+
overrides = {
98+
"ajet.experiment_dir": "experiment_dir",
99+
"ajet.project_name": "project_name",
100+
"ajet.experiment_name": "experiment_name",
101+
"ajet.model.path": "model",
102+
"ajet.trainer_common.n_gpus_per_node": "n_gpu",
103+
"ajet.trainer_common.algorithm.adv_estimator": "algorithm",
104+
"ajet.rollout.num_repeat": "num_repeat",
105+
"ajet.data.train_batch_size": "batch_size",
106+
"ajet.enable_swarm_mode": "swarm_mode",
107+
"ajet.swarm_mode_sample_collection_method": "swarm_mode_sample_collection_method",
108+
"ajet.rollout.max_env_worker": "max_env_worker",
109+
"ajet.backbone": "backbone",
110+
}
111+
112+
# if any value given in kwargs, override the corresponding value in config
113+
for attr_path, override_val in overrides.items():
114+
# get value from yaml config
115+
# >> e.g. current_model = self.config.model.path
116+
current_val = _get_nested_attr(self.config, attr_path)
117+
118+
# if override_val (given in __init__) is not None, use it to override the value from yaml config
119+
# >> e.g. new_model = self.model if (self.model is not None) else current_model
120+
new_val = override_current_yaml_value_if_given(getattr(self, override_val), current_val)
121+
122+
# write final value to `self.config``
123+
# >> e.g. self.config.model.path = new_model
124+
_set_nested_attr(self.config, attr_path, new_val)
125+
126+
# write final value to `self`
127+
# >> e.g. self.model = new_model
128+
setattr(self, override_val, new_val)
129+
130+
if self.backbone == "trinity":
131+
raise NotImplementedError("Trinity backbone is not yet supported in AgentJetJob.")
132+
84133

85134
def build_job_from_yaml(self, yaml_path: str | None) -> dict:
86135
self.config_as_dict = read_ajet_hierarchical_config(
87136
yaml_path,
88-
exp_name=self.exp_name,
89-
backbone=self.backbone,
90137
write_to=None,
91-
exp_dir=self.exp_dir,
92138
)
93139
self.config_as_dict = expand_ajet_hierarchical_config(self.config_as_dict, write_to=None)
94140
logger.info(f"Built AgentJet job config: {yaml_path}")
95141
return self.config_as_dict
96142

143+
97144
def dump_job_as_yaml(self, yaml_path: str) -> str:
98145
if os.path.dirname(yaml_path):
99146
os.makedirs(os.path.dirname(yaml_path), exist_ok=True)
@@ -102,6 +149,7 @@ def dump_job_as_yaml(self, yaml_path: str) -> str:
102149
logger.info(f"Saved training config to {yaml_path}")
103150
return yaml_path
104151

152+
105153
def set_workflow(
106154
self, workflow: Union[str, Callable[..., Any]], ensure_reward_in_workflow: bool = False
107155
) -> "AgentJetJob":
@@ -110,6 +158,7 @@ def set_workflow(
110158
# ensure_reward_in_workflow
111159
return self
112160

161+
113162
def set_data(
114163
self,
115164
type: str,
@@ -136,60 +185,3 @@ def set_data(
136185

137186
return self
138187

139-
def tune(self, *args, **kwargs) -> "AgentJetJob":
140-
import ray
141-
ast_cfg = self.config.ajet
142-
if not ast_cfg.rollout or not ast_cfg.rollout.user_workflow:
143-
raise ValueError("Workflow must be set via set_workflow before tuning.")
144-
if not ast_cfg.task_reader:
145-
raise ValueError("Data source must be set via set_data before tuning.")
146-
147-
backbone = self.config.ajet.backbone
148-
exp_dir = self.config.ajet.experiment_dir
149-
150-
with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".yaml") as temp_yaml:
151-
yaml_path = temp_yaml.name
152-
self.dump_job_as_yaml(yaml_path)
153-
args = SimpleNamespace(
154-
conf=yaml_path,
155-
backbone=backbone,
156-
exp_dir=exp_dir,
157-
with_logview=False,
158-
debug=False,
159-
)
160-
161-
if args.backbone != "debug":
162-
# Enforce GPU availability and free memory threshold before proceeding
163-
check_avail_gpu(min_free_ratio=0.95)
164-
165-
# finalize experiment config
166-
main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config(
167-
yaml_path, exp_dir, backbone
168-
)
169-
170-
# setup environment variables for ray
171-
env = setup_environment_vars(args, exp_config, main_yaml_fp)
172-
173-
# start ray if not already started
174-
if not ray.is_initialized():
175-
from ajet.utils.launch_utils import start_ray_service
176-
177-
start_ray_service(args, env)
178-
else:
179-
raise RuntimeError(
180-
"Ray is already initialized. Please shutdown existing Ray instance before starting a new tuning job."
181-
)
182-
183-
# start training process
184-
if args.conf and main_yaml_fp and exe_exp_base and exp_config:
185-
execute_training_process(
186-
args,
187-
get_backbone_target(args.backbone),
188-
main_yaml_fp,
189-
exe_exp_base,
190-
main_yaml_fp,
191-
env,
192-
exp_config,
193-
)
194-
195-
return self

0 commit comments

Comments (0)