Skip to content

Commit ac6ab47

Browse files
committed
allow user to use any port as swarm port
1 parent 720abc6 commit ac6ab47

6 files changed

Lines changed: 24 additions & 27 deletions

File tree

ajet/swarm_cli.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ def start_swarm_server(env, config, port):
2929
)
3030

3131
# Set the port in the config
32-
if hasattr(config.ajet, 'experimental_interchange_server_port'):
33-
config.ajet.experimental_interchange_server_port = port
32+
config.ajet.interchange_server.interchange_server_port = port
3433

3534
from ajet.tuner_lib.experimental.as_oai_model_server import (
3635
start_interchange_server,
@@ -62,7 +61,7 @@ def cmd_start(args):
6261
exp_name,
6362
exp_config,
6463
) = prepare_experiment_config(
65-
yaml_path, exp_dir, args.backbone, storage=False
64+
yaml_path, exp_dir, "verl", storage=False
6665
)
6766

6867
# Setup environment variables
@@ -75,7 +74,7 @@ def __init__(self, conf, backbone, exp_dir):
7574
self.swarm_overwatch = ""
7675
self.debug = ""
7776

78-
swarm_args = SwarmArgs(args.conf, args.backbone, args.exp_dir)
77+
swarm_args = SwarmArgs(args.conf, "verl", args.exp_dir)
7978
env, exp_config = setup_environment_vars(swarm_args, exp_config, main_yaml_fp)
8079

8180
# Start swarm server
@@ -86,8 +85,8 @@ def cmd_overwatch(args):
8685
"""Handle the 'overwatch' subcommand."""
8786
from ajet.utils.swarm_overwatch import start_overwatch
8887

89-
logger.info(f"Starting Swarm Overwatch for server: {args.swarm_port}")
90-
start_overwatch(args.swarm_port, refresh_interval=args.refresh_interval)
88+
logger.info(f"Starting Swarm Overwatch for server: {args.swarm_url}")
89+
start_overwatch(args.swarm_url, refresh_interval=args.refresh_interval)
9190

9291

9392
def main():
@@ -117,19 +116,13 @@ def main():
117116
required=False,
118117
help="Path to experiment directory",
119118
)
120-
parser_start.add_argument(
121-
"--backbone",
122-
type=str,
123-
default="verl",
124-
required=False,
125-
help="verl or trinity or debug",
126-
)
119+
127120
parser_start.set_defaults(func=cmd_start)
128121

129122
# Subcommand: overwatch
130123
parser_overwatch = subparsers.add_parser("overwatch", help="Monitor the swarm server")
131124
parser_overwatch.add_argument(
132-
"--swarm-port",
125+
"--swarm-url",
133126
type=str,
134127
default="http://localhost:10086",
135128
required=False,

ajet/tuner_lib/experimental/as_swarm_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,8 @@ def print_rollout_stat(self):
469469
"Completed episodes: (current) / (required)": f"{stat.get('completed_episodes', 0)} / {stat.get('completed_episode_target', 0)}",
470470
"Average episodes per task: (current) / (expected)": f"{stat.get('average_episodes_per_task', 0):.2f} / {stat.get('task_expected_num_repeat', 0)}",
471471
"Completed num-dummy tasks: (current) / (required)": f"{stat.get('completed_non_dummy_tasks', 0)} / {stat.get('completed_task_target', 0)}",
472-
"Tasks (Number of episodes completed for each task)": task_buffer
472+
"Tasks (Number of episodes completed for each task)": task_buffer,
473+
"Hint": f"Please run `ajet-swarm overwatch --swarm-url={self.server_url}` to get more details."
473474
}
474475
print_dict(stat, mod="console", header="Current Swarm Rollout Pool Information")
475476
except:

ajet/tuner_lib/experimental/as_swarm_server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -625,8 +625,8 @@ async def abort_episode(req: EndEpisodeRequest):
625625
if VERBOSE:
626626
logger.info(f"Running [{episode_uuid}]: /abort_episode")
627627

628-
assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id"
629-
assert workflow_output.metadata["task_id"] == task_id, "workflow_output.metadata.task_id must match req.task_id"
628+
# assert "task_id" in workflow_output.metadata, "workflow_output.metadata must contain task_id"
629+
# assert workflow_output.metadata["task_id"] == task_id, "workflow_output.metadata.task_id must match req.task_id"
630630

631631
if "episodes" not in shared_mem_dict:
632632
logger.error(f"[server] No episodes registered yet.")
@@ -715,6 +715,7 @@ async def get_current_batch_rollout_pool_information():
715715
running_episode_details[es.episode_uuid] = {
716716
"episode_status": es.episode_status,
717717
"time_since_last_activity": f"{time_since_last_activity:.1f}s",
718+
"discard_episode_timeout": f"{es.discard_episode_timeout:.1f}s",
718719
}
719720
pool_info.running_episode_details = running_episode_details if running_episode_details else None
720721

ajet/tuner_lib/experimental/swarm_overwatch_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class CurrentBatchRolloutPoolInformation(BaseModel):
1212
completed_non_dummy_task_target: int = 0
1313
task_expected_num_repeat: int = 0
1414
completed_tasks_details: Dict[str, List[str]] = {} # task_id -> list of episode_uuids
15-
running_episode_details: Dict[str, Dict[str, str]] | None = None # episode_uuid -> { "episode_status": ..., "time_since_last_activity": ...}
15+
running_episode_details: Dict[str, Dict[str, str]] | None = None # episode_uuid -> { "episode_status": ..., "time_since_last_activity": ..., "discard_episode_timeout": ...}
1616
engine_status: str | None = None
1717
global_step: int | None = None
1818
booting_start_time: float | None = None # timestamp when ENGINE.BOOTING started

ajet/utils/swarm_overwatch.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def create_summary_table(self, info: CurrentBatchRolloutPoolInformation) -> Tabl
113113
title_prefix = "" if is_active else "[WAITING ENGINE.ROLLING] "
114114

115115
table = Table(
116-
title=f"{title_prefix}Rollout Pool Summary (Progress to Hit Next Weight Update)",
116+
title=f"{title_prefix}Completed Episode Pool Summary (Progress to Hit Next Weight Update)",
117117
show_header=True,
118118
header_style="bold magenta",
119119
border_style=border_style,
@@ -140,13 +140,13 @@ def create_summary_table(self, info: CurrentBatchRolloutPoolInformation) -> Tabl
140140

141141
# Episodes
142142
ep_cur, ep_tgt, ep_pct = self.create_progress_bar(
143-
info.completed_episodes, info.completed_episode_target, "Episodes"
143+
info.completed_episodes, info.completed_episode_target, "Completed Episodes"
144144
)
145145
ep_bar = self._create_text_bar(ep_pct)
146146
ep_metric = (
147-
"*Episodes (chosen)*"
147+
"-> *Completed Episodes (chosen)*"
148148
if highlight_episodes
149-
else "Episodes"
149+
else "Completed Episodes"
150150
)
151151
ep_style = "bold green" if highlight_episodes else None
152152
table.add_row(
@@ -166,7 +166,7 @@ def create_summary_table(self, info: CurrentBatchRolloutPoolInformation) -> Tabl
166166
)
167167
task_bar = self._create_text_bar(task_pct)
168168
task_metric = (
169-
"*Completed Tasks (chosen)*" if highlight_tasks else "Completed Tasks"
169+
"-> *Completed Tasks (chosen)*" if highlight_tasks else "Completed Tasks"
170170
)
171171
task_style = "bold green" if highlight_tasks else None
172172
table.add_row(
@@ -188,7 +188,7 @@ def create_summary_table(self, info: CurrentBatchRolloutPoolInformation) -> Tabl
188188
)
189189
nd_bar = self._create_text_bar(nd_pct)
190190
nd_metric = (
191-
"*Completed Non-Dummy Tasks (chosen)*"
191+
"-> *Completed Non-Dummy Tasks (chosen)*"
192192
if highlight_non_dummy
193193
else "Completed Non-Dummy Tasks"
194194
)
@@ -245,7 +245,7 @@ def create_running_episodes_table(
245245

246246
table.add_column("Episode UUID", style="cyan", no_wrap=True, width=20, overflow="ellipsis")
247247
table.add_column("Status", style="green", width=15)
248-
table.add_column("Time Since Last Activity", style="yellow", width=30)
248+
table.add_column("Last Req / Patience", style="yellow", width=30)
249249

250250
if not info.running_episode_details:
251251
table.add_row("[dim]No running episodes[/dim]", "", "")
@@ -259,10 +259,12 @@ def create_running_episodes_table(
259259
)
260260

261261
for episode_uuid, details in sorted_episodes[:30]:
262+
last_req = details["time_since_last_activity"]
263+
patience = details.get("discard_episode_timeout", "N/A")
262264
table.add_row(
263265
episode_uuid[:40] if len(episode_uuid) > 40 else episode_uuid,
264266
details["episode_status"],
265-
details["time_since_last_activity"],
267+
f"{last_req} / {patience}",
266268
)
267269

268270
if len(sorted_episodes) > 30:

tutorial/example_academic_trans/trans_roll.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
LOCAL_NUM_EPOCH = 1
1414
LOCAL_MAX_PARALLEL = 32
1515
LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/agentjet/agentjet/tmp/arxiv_papers/train.parquet"
16-
REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url
16+
REMOTE_SWARM_URL = "http://localhost:10099" # Change to your swarm remote url
1717

1818
# --------- configurations that take effect remotely -------------
1919
REMOTE_BATCH_SIZE = 8

0 commit comments

Comments
 (0)