Skip to content

Commit 1604076

Browse files
committed
add stop engine return argument
1 parent 094673e commit 1604076

3 files changed

Lines changed: 10 additions & 4 deletions

File tree

ajet/tuner_lib/experimental/as_swarm_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,10 @@ def auto_sync_train_config_and_start_engine(self, agent_jet_job: AgentJetJob, fo
544544
self.logger_info("Engine is already ROLLING. No action needed.")
545545
elif current_status == "ENGINE.ROLLING_POST":
546546
self.logger_info("Engine is already ROLLING. No action needed.")
547-
elif current_status in ["ENGINE.BOOTING", "ENGINE.CANNOT_CONNECT", "ENGINE.WEIGHT_SYNCING"]:
547+
elif current_status in ["ENGINE.CANNOT_CONNECT"]:
548+
logger.error("Unable to connect to swarm server.")
549+
raise RuntimeError(f"Unable to connect to swarm server.")
550+
elif current_status in ["ENGINE.BOOTING", "ENGINE.WEIGHT_SYNCING"]:
548551
self.logger_info(f"Engine is {current_status}. Waiting until it becomes ROLLING...")
549552
self._wait_until_status_change_to(desired_status="ENGINE.ROLLING")
550553
logger.success("Training engine is now ROLLING and ready.")
@@ -568,7 +571,7 @@ def stop_engine(self):
568571
)
569572
raise_for_status_with_detail(resp)
570573
result = resp.json()
571-
if result.get("success"):
574+
if result and result.get("success"):
572575
self.logger_info("Successfully stopped training engine on Swarm server")
573576
else:
574577
logger.error("Failed to stop training engine")

ajet/tuner_lib/experimental/as_swarm_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,5 +778,6 @@ async def stop_engine():
778778
- Clean up shared memory state
779779
"""
780780
kill_process_tree(shared_mem_dict_lock, shared_mem_dict)
781+
return BoolResponse(success=True)
781782

782783
return app, register_episode_ready_listener()

tutorial/example_math_swarm/math.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323
REMOTE_BATCH_SIZE = 32
2424
REMOTE_ALLOCATE_GPU_PER_NODE = 8
25-
REMOTE_TRAIN_MODEL = '/root/agentjet/modelscope_cache/Qwen/Qwen2.5-7B-Instruct'
25+
# REMOTE_TRAIN_MODEL = '/root/agentjet/modelscope_cache/Qwen/Qwen2.5-7B-Instruct'
26+
REMOTE_TRAIN_MODEL = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct'
2627

2728
def main():
2829

@@ -48,7 +49,8 @@ def main():
4849
model=REMOTE_TRAIN_MODEL,
4950
batch_size=REMOTE_BATCH_SIZE,
5051
num_repeat=GRPO_N,
51-
)
52+
),
53+
force_restart=True,
5254
)
5355

5456
def rollout(task):

0 commit comments

Comments
 (0)