Skip to content

Commit df222cd

Browse files
committed
add swarm overwatch
1 parent a127b84 commit df222cd

File tree

5 files changed

+376
-47
lines changed

5 files changed

+376
-47
lines changed

ajet/launcher.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,10 @@ def parse_args():
6363
help="Path to configuration file",
6464
)
6565
parser.add_argument(
66-
"--with-ray",
67-
action="store_true",
68-
default=False,
69-
help="Launch ray"
66+
"--with-ray", action="store_true", default=False, help="Launch ray"
7067
)
7168
parser.add_argument(
72-
"--with-ray-cluster",
73-
action="store_true",
74-
default=False,
75-
help="Launch ray"
69+
"--with-ray-cluster", action="store_true", default=False, help="Launch ray"
7670
)
7771
parser.add_argument(
7872
"--with-appworld",
@@ -93,10 +87,7 @@ def parse_args():
9387
help="Launch webshop",
9488
)
9589
parser.add_argument(
96-
"--with-bfcl",
97-
action="store_true",
98-
default=False,
99-
help="Launch bfcl"
90+
"--with-bfcl", action="store_true", default=False, help="Launch bfcl"
10091
)
10192
parser.add_argument(
10293
"--with-logview",
@@ -114,7 +105,7 @@ def parse_args():
114105
"--skip-check-avail-gpu",
115106
action="store_true",
116107
default=False,
117-
help="Skip GPU availability check"
108+
help="Skip GPU availability check",
118109
)
119110
parser.add_argument(
120111
"--kill",
@@ -134,7 +125,14 @@ def parse_args():
134125
type=str,
135126
default="",
136127
required=False,
137-
help="Prefix for deepfinance service names"
128+
help="Prefix for deepfinance service names",
129+
)
130+
parser.add_argument(
131+
"--swarm-overwatch",
132+
type=str,
133+
default="",
134+
required=False,
135+
help="Swarm server URL for overwatch monitoring (e.g., http://localhost:10086)",
138136
)
139137
return parser.parse_args()
140138

@@ -143,22 +141,37 @@ def check_model_file_exists(exp_config):
143141
model_path = exp_config["ajet"]["model"]["path"]
144142
# if model_path has more than 2 '/', we consider it as a dir path
145143
if model_path.count("/") > 2:
146-
assert os.path.exists(model_path), f"Model path {model_path} does not exist. Please check your configuration."
144+
assert os.path.exists(model_path), (
145+
f"Model path {model_path} does not exist. Please check your configuration."
146+
)
147147

148148

149149
def start_swarm_server(env, config):
150150
config = dict_to_namespace(config)
151-
assert config.ajet.enable_swarm_mode, \
151+
assert config.ajet.enable_swarm_mode, (
152152
"Please enable_swarm_mode in config to start swarm server."
153-
assert config.ajet.enable_experimental_interchange_server, \
153+
)
154+
assert config.ajet.enable_experimental_interchange_server, (
154155
"Please enable_experimental_interchange_server in config to start swarm server."
155-
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
156+
)
157+
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import (
158+
start_interchange_server,
159+
)
160+
156161
start_interchange_server(config, blocking=True, env=env)
157162

158163

159164
def main():
160165
args = parse_args()
161166

167+
# Handle swarm overwatch mode
168+
if args.swarm_overwatch:
169+
from ajet.utils.swarm_overwatch import start_overwatch
170+
171+
logger.info(f"Starting Swarm Overwatch for server: {args.swarm_overwatch}")
172+
start_overwatch(args.swarm_overwatch, refresh_interval=1.0)
173+
return
174+
162175
# Enforce GPU availability and free memory threshold before proceeding
163176
if not args.skip_check_avail_gpu:
164177
if (args.backbone != "debug") and (not args.kill) and (not args.autokill):
@@ -174,7 +187,9 @@ def main():
174187
logger.info(f"Killing processes matching keyword: {keyword}")
175188
killed_pids = fast_kill_by_keyword_bash(keyword)
176189
if killed_pids:
177-
logger.success(f"Successfully killed processes with PIDs: {killed_pids}")
190+
logger.success(
191+
f"Successfully killed processes with PIDs: {killed_pids}"
192+
)
178193
else:
179194
logger.warning(f"No processes found matching keyword: {keyword}")
180195
if not args.conf:
@@ -192,16 +207,24 @@ def main():
192207
exp_config = None
193208
exp_dir = args.exp_dir or "saved_experiments"
194209
if args.swarm_server and (not args.conf):
195-
args.conf = os.path.abspath(os.path.join(os.path.dirname(__file__), "default_config/ajet_ts_default.yaml"))
196-
assert os.path.exists(args.conf), "Please provide a valid config file for swarm server mode."
210+
args.conf = os.path.abspath(
211+
os.path.join(
212+
os.path.dirname(__file__), "default_config/ajet_ts_default.yaml"
213+
)
214+
)
215+
assert os.path.exists(args.conf), (
216+
"Please provide a valid config file for swarm server mode."
217+
)
197218
if args.conf:
198219
yaml_path = args.conf
199220
(
200221
main_yaml_fp,
201222
exe_exp_base,
202223
exp_name,
203224
exp_config,
204-
) = prepare_experiment_config(yaml_path, exp_dir, args.backbone, storage=(not args.swarm_server))
225+
) = prepare_experiment_config(
226+
yaml_path, exp_dir, args.backbone, storage=(not args.swarm_server)
227+
)
205228

206229
# setup environment variables
207230
env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp)
@@ -211,9 +234,9 @@ def main():
211234
return
212235

213236
if args.with_ray:
214-
assert (
215-
not args.with_ray_cluster
216-
), "Cannot use both --with-ray and --with-ray-cluster simultaneously."
237+
assert not args.with_ray_cluster, (
238+
"Cannot use both --with-ray and --with-ray-cluster simultaneously."
239+
)
217240
start_ray_service(args, env)
218241

219242
if args.with_appworld:
@@ -235,9 +258,9 @@ def main():
235258
launch_logview(exp_name)
236259

237260
if args.with_ray_cluster:
238-
assert (
239-
not args.with_ray
240-
), "Cannot use both --with-ray and --with-ray-cluster simultaneously."
261+
assert not args.with_ray, (
262+
"Cannot use both --with-ray and --with-ray-cluster simultaneously."
263+
)
241264
start_ray_service(args, env, cluster=True)
242265

243266
if args.conf and main_yaml_fp and exe_exp_base and exp_config:

ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,29 @@ def __init__(self, server_url: str):
5454
self.record_episode_expire_time = {}
5555
self.auto_batching_tasks = []
5656

57+
# better logging management
58+
self._last_second_print_buffer: dict[str, float] = {}
59+
60+
def logger_info(self, message):
    """Log *message* at INFO level, suppressing duplicates within 1 second.

    Prevents log flooding when the same message is emitted repeatedly in a
    tight polling loop (e.g. while waiting for the engine to reach
    ENGINE.ROLLING). A message is logged only if it was not already logged
    during the last second; its emission timestamp is recorded in
    ``self._last_second_print_buffer`` (message -> unix timestamp).

    Args:
        message: The log line to emit (used verbatim as the dedup key).

    Returns:
        None.
    """
    # Single timestamp per call so the suppress test and the prune below
    # use a consistent notion of "now".
    now = time.time()

    last_emitted = self._last_second_print_buffer.get(message)
    if last_emitted is not None and now - last_emitted < 1:
        return  # same message emitted <1s ago -> suppress

    self._last_second_print_buffer[message] = now
    logger.info(message)

    # Prune entries older than 1s on EVERY emission (the original only
    # pruned on the re-emit path, so unique messages leaked forever).
    self._last_second_print_buffer = {
        msg: ts
        for msg, ts in self._last_second_print_buffer.items()
        if now - ts <= 1
    }
79+
5780

5881
def _clean_up_expired_records(self):
5982
# remove records that have expired and expired at least CLEAN_RECORD_TIMEOUT seconds ago
@@ -82,7 +105,7 @@ def begin_episode(self, discard_episode_timeout=60, max_episode_time=120, episod
82105
"""
83106
status, status_json = self.get_engine_status() # warm up connection and log the status
84107
if status not in ["ENGINE.ROLLING"]:
85-
logger.info(f"Engine status is {status}. Waiting until ENGINE.ROLLING...")
108+
self.logger_info(f"Engine status is {status}. Waiting until ENGINE.ROLLING...")
86109
self._wait_until_status_change_to(desired_status="ENGINE.ROLLING", verbose=False)
87110

88111
while True:
@@ -107,7 +130,7 @@ def begin_episode(self, discard_episode_timeout=60, max_episode_time=120, episod
107130
episode_uuid = data.episode_uuid
108131
openai_base_url = data.openai_base_url
109132
openai_api_key = data.openai_api_key
110-
logger.info(f"Claimed episode {episode_uuid}, current global step: {status_json.get('global_step', 'unknown')}")
133+
self.logger_info(f"Claimed episode {episode_uuid}, current global step: {status_json.get('global_step', 'unknown')}")
111134
return episode_uuid, OpenaiBaseUrlAndApiKey(
112135
base_url=openai_base_url,
113136
api_key=openai_api_key,
@@ -121,7 +144,7 @@ def begin_episode(self, discard_episode_timeout=60, max_episode_time=120, episod
121144
]
122145
if any(scenario in data.fail_cause for scenario in need_wait_scenarios):
123146
if time.time() - self.previous_warning_time > 60:
124-
logger.info(f"{data.fail_cause}. Retrying in 15s...")
147+
self.logger_info(f"{data.fail_cause}. Retrying in 15s...")
125148
self.previous_warning_time = time.time()
126149
time.sleep(15)
127150
else:
@@ -169,7 +192,7 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut
169192
data = EndEpisodeResponse.model_validate(resp.json())
170193

171194
if data.success:
172-
logger.info(f"Ended episode {episode_uuid}")
195+
self.logger_info(f"Ended episode {episode_uuid}")
173196
else:
174197
logger.error(f"Failed to end episode {episode_uuid}")
175198
raise RuntimeError(f"Failed to end episode {episode_uuid}")
@@ -198,7 +221,7 @@ def abort_episode(self, episode_uuid: str):
198221
data = EndEpisodeResponse.model_validate(resp.json())
199222

200223
if data.success:
201-
logger.info(f"Aborted episode {episode_uuid}")
224+
self.logger_info(f"Aborted episode {episode_uuid}")
202225
else:
203226
logger.error(f"Failed to end episode {episode_uuid}")
204227

@@ -227,7 +250,7 @@ def sync_train_config(self, agent_jet_job: AgentJetJob):
227250
timeout=GENERAL_TIMEOUT
228251
)
229252
raise_for_status_with_detail(resp)
230-
logger.info("Synced train config to Swarm server")
253+
self.logger_info("Synced train config to Swarm server")
231254
except Exception as e:
232255
logger.error(f"Error syncing train config: {e}")
233256
raise
@@ -252,7 +275,7 @@ def start_engine(self):
252275
raise_for_status_with_detail(resp)
253276
result = resp.json()
254277
if result.get("success"):
255-
logger.info("Successfully started training engine on Swarm server (current model global step)")
278+
self.logger_info("Successfully started training engine on Swarm server (current model global step)")
256279
else:
257280
logger.error("Failed to start training engine")
258281
raise RuntimeError("Failed to start training engine")
@@ -267,7 +290,7 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=
267290
Reports status every 5 seconds while waiting.
268291
"""
269292
if verbose:
270-
logger.info(f"Polling engine status until {desired_status}...")
293+
self.logger_info(f"Polling engine status until {desired_status}...")
271294
last_report_time = time.time()
272295
init_poll_time = last_report_time
273296

@@ -279,13 +302,13 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=
279302
# Report status every 5 seconds
280303
if current_time - last_report_time >= 30:
281304
if verbose:
282-
logger.info(f"Current engine status (already waited {int(current_time - init_poll_time)}s): {current_status}")
305+
self.logger_info(f"Current engine status (already waited {int(current_time - init_poll_time)}s): {current_status}")
283306
last_report_time = current_time
284307

285308
# Check if engine has reached the desired status
286309
if current_status == desired_status:
287310
if verbose:
288-
logger.info(f"Engine status is {desired_status}.")
311+
self.logger_info(f"Engine status is {desired_status}.")
289312
break
290313

291314
# Wait a bit before next poll
@@ -363,15 +386,15 @@ def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob, fo
363386
time.sleep(8)
364387
current_status, _ = self.get_engine_status()
365388
if current_status == "ENGINE.OFFLINE":
366-
logger.info("Engine is OFFLINE. Syncing train config and starting engine...")
389+
self.logger_info("Engine is OFFLINE. Syncing train config and starting engine...")
367390
self.sync_train_config(agent_jet_job)
368391
self.start_engine()
369392
elif current_status == "ENGINE.ROLLING":
370-
logger.info("Engine is already ROLLING. No action needed.")
393+
self.logger_info("Engine is already ROLLING. No action needed.")
371394
elif current_status == "ENGINE.ROLLING_POST":
372-
logger.info("Engine is already ROLLING. No action needed.")
395+
self.logger_info("Engine is already ROLLING. No action needed.")
373396
elif current_status == "ENGINE.BOOTING":
374-
logger.info("Engine is BOOTING. Waiting until it becomes ROLLING...")
397+
self.logger_info("Engine is BOOTING. Waiting until it becomes ROLLING...")
375398
self._wait_until_status_change_to(desired_status="ENGINE.ROLLING")
376399
logger.success("Training engine is now ROLLING and ready.")
377400
elif current_status == "ENGINE.CANNOT_CONNECT":
@@ -388,7 +411,7 @@ def stop_engine(self):
388411
"""
389412
current_status, _ = self.get_engine_status()
390413
if current_status == "ENGINE.OFFLINE":
391-
logger.info("Engine is already OFFLINE. No action needed.")
414+
self.logger_info("Engine is already OFFLINE. No action needed.")
392415
return
393416

394417
resp = httpx.post(
@@ -399,7 +422,7 @@ def stop_engine(self):
399422
raise_for_status_with_detail(resp)
400423
result = resp.json()
401424
if result.get("success"):
402-
logger.info("Successfully stopped training engine on Swarm server")
425+
self.logger_info("Successfully stopped training engine on Swarm server")
403426
else:
404427
logger.error("Failed to stop training engine")
405428
raise RuntimeError("Failed to stop training engine")
@@ -502,5 +525,5 @@ def rollout(task) -> float | None:
502525
if len(episodes) == (remote_batch_size * local_grpo_n):
503526
episode_results = run_episodes_until_all_complete(episodes, func=rollout, auto_retry=True)
504527
for episode, reward in zip(episodes, episode_results):
505-
logger.info(f"Episode for task {episode.task_id} completed with reward: {reward}")
528+
self.logger_info(f"Episode for task {episode.task_id} completed with reward: {reward}")
506529
episodes.clear()

ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ async def end_episode(req: EndEpisodeRequest):
554554

555555
if episode_status != "claimed":
556556
logger.error(f"[server] Episode {episode_uuid} is not in claimed status.")
557-
raise HTTPException(status_code=400, detail=f"Episode {episode_uuid} is not in claimed status, maybe you take **too long** to submit the workflow output, increase `discard_episode_timeout` when `begin_episode`.")
557+
raise HTTPException(status_code=400, detail=f"Episode {episode_uuid} is not in claimed status, maybe you take **too long** to submit the workflow output, try increase `discard_episode_timeout` when `begin_episode`.")
558558

559559
if client_uuid_recorded != client_uuid:
560560
logger.error(f"[server] Episode {episode_uuid} is claimed by different client: {client_uuid_recorded}, but got {client_uuid}.")

0 commit comments

Comments
 (0)