Skip to content

Commit a191f71

Browse files
committed
improve communication protocol
1 parent e9bc0e1 commit a191f71

File tree

21 files changed

+264
-915
lines changed

21 files changed

+264
-915
lines changed

ajet/context_tracker/multiagent_tracking.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -84,19 +84,6 @@ def extract_text_content_from_content_dict(self, msg):
8484
# },
8585
# ],
8686
# }
87-
# or possibly a tool_result format (not observed yet):
88-
# msg = {
89-
# "role": "tool",
90-
# "content": [
91-
# {
92-
# "type": "tool_result",
93-
# "id": "call_xxx",
94-
# "output": "tool output content",
95-
# "name": "tool_name"
96-
# },
97-
# ],
98-
# }
99-
10087

10188
str_content = ""
10289
for item in msg["content"]:

ajet/default_config/ajet_default.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
ajet:
33
project_name: "ajet_default_project"
44
experiment_name: "read_yaml_name"
5-
experiment_dir: "auto" # {exp-dir}/{experiment_name}
5+
experiment_dir: "{exp-dir}/{experiment_name}" # {exp-dir}/{experiment_name}
66
backbone: verl # `debug` or `trinity` or `verl`
77

88

@@ -85,6 +85,7 @@ ajet:
8585
num_repeat: 1
8686

8787

88+
8889
task_reader:
8990
# how to read dataset / environment
9091
type: huggingface_dat_repo # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy`
@@ -306,8 +307,6 @@ ajet:
306307
swarm_mode_sample_collection_max_cached_episodes: 9999
307308

308309
task_runner:
309-
# LLM inference submission method
310-
llm_infer_submit_method: "async" # options: "sync", "async"
311310

312311
# how to wrap the user-defined workflow
313312
wrapper_type: "asyncio-with-gc"

ajet/default_config/ajet_ts_default.yaml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
ajet:
33
project_name: "ajet_default_project"
44
experiment_name: "read_yaml_name"
5-
experiment_dir: "auto" # {exp-dir}/{experiment_name}
5+
experiment_dir: "{exp-dir}/{experiment_name}" # {exp-dir}/{experiment_name}
66
backbone: verl
77

88
model:
@@ -12,6 +12,10 @@ ajet:
1212
rollout:
1313
# the path to the workflow class
1414
user_workflow: null
15+
# maximum number of parallel environments / simulation workers
16+
max_env_worker: 128
17+
# how many times a task should be repeated
18+
num_repeat: 4
1519

1620
task_reader:
1721
type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy`
@@ -53,12 +57,6 @@ ajet:
5357
train_batch_size: 32
5458
# [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps)
5559

56-
rollout:
57-
# maximum number of parallel environments / simulate workers
58-
max_env_worker: 128
59-
# how many times a task should be repeated
60-
num_repeat: 4
61-
6260
trainer_common:
6361
logger: tensorboard
6462
n_gpus_per_node: 8

ajet/launcher.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def main():
168168
from ajet.utils.swarm_overwatch import start_overwatch
169169

170170
logger.info(f"Starting Swarm Overwatch for server: {args.swarm_overwatch}")
171-
start_overwatch(args.swarm_overwatch, refresh_interval=1.0)
171+
start_overwatch(args.swarm_overwatch, refresh_interval=2.0)
172172
return
173173

174174
# Enforce GPU availability and free memory threshold before proceeding
@@ -204,7 +204,6 @@ def main():
204204

205205
# read configuration from yaml
206206
exp_config = None
207-
exp_dir = args.exp_dir or DEFAULT_DIR
208207
if args.swarm_server and (not args.conf):
209208
args.conf = os.path.abspath(
210209
os.path.join(
@@ -215,14 +214,18 @@ def main():
215214
"Please provide a valid config file for swarm server mode."
216215
)
217216
if args.conf:
217+
exp_dir = args.exp_dir or DEFAULT_DIR
218218
yaml_path = args.conf
219219
(
220220
main_yaml_fp,
221221
exe_exp_base,
222222
exp_name,
223223
exp_config,
224224
) = prepare_experiment_config(
225-
yaml_path, exp_dir, args.backbone, storage=(not args.swarm_server)
225+
yaml_path=yaml_path,
226+
exp_base_dir=exp_dir,
227+
backbone=args.backbone,
228+
storage=(not args.swarm_server)
226229
)
227230

228231
# setup environment variables

ajet/swarm_cli.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ def cmd_start(args):
6161
exp_name,
6262
exp_config,
6363
) = prepare_experiment_config(
64-
yaml_path, exp_dir, "verl", storage=False
64+
yaml_path=yaml_path,
65+
exp_base_dir=exp_dir,
66+
backbone="verl",
67+
storage=False
6568
)
6669

6770
# Setup environment variables
@@ -73,7 +76,6 @@ def __init__(self, conf, backbone, exp_dir):
7376
self.swarm_server = True
7477
self.swarm_overwatch = ""
7578
self.debug = ""
76-
7779
swarm_args = SwarmArgs(args.conf, "verl", args.exp_dir)
7880
env, exp_config = setup_environment_vars(swarm_args, exp_config, main_yaml_fp)
7981

@@ -131,9 +133,9 @@ def main():
131133
parser_overwatch.add_argument(
132134
"--refresh-interval",
133135
type=float,
134-
default=1.0,
136+
default=2.0,
135137
required=False,
136-
help="Refresh interval in seconds (default: 1.0)",
138+
help="Refresh interval in seconds (default: 2.0)",
137139
)
138140
parser_overwatch.set_defaults(func=cmd_overwatch)
139141

0 commit comments

Comments
 (0)