Skip to content

Commit 2f83609

Browse files
committed
2 parents 649d742 + 32de692 commit 2f83609

File tree

4 files changed

+107
-22
lines changed

4 files changed

+107
-22
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,5 @@ dataset_gsm8k/*
167167
benchmark_datasets
168168
modelscope_cache
169169
prompts
170+
swarmexp
171+
swarmlog

ajet/tuner_lib/experimental/as_swarm_server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,8 @@ def override_param_callback(config):
377377
from ajet.utils.launch_utils import start_ray_service
378378

379379
logger.info("[start_engine] Starting Ray service...")
380-
start_ray_service(args, env)
380+
# start_ray_service(args, env)
381+
await asyncio.to_thread(start_ray_service, args, env) # start ray in separate thread to avoid blocking
381382
else:
382383
logger.info("[start_engine] Ray already initialized")
383384

tutorial/example_frozenlake_swarm/frozen_lake_roll.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,13 @@
1414
# --------- configurations that take effect locally -------------
1515
LOCAL_GRPO_N = 4 # grpo group size
1616
LOCAL_NUM_EPOCH = 10000
17-
LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/agentjet/agentjet/tmp/arxiv_papers/train.parquet"
1817

1918
# --------- configurations that take effect remotely -------------
2019
REMOTE_BATCH_SIZE = 32
2120
REMOTE_1_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url
2221
REMOTE_1_ALLOCATE_GPU_PER_NODE = 4
2322
REMOTE_1_TRAIN_MODEL = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct'
24-
REMOTE_2_SWARM_URL = "http://localhost:10087" # Change to your swarm remote url
25-
REMOTE_2_ALLOCATE_GPU_PER_NODE = 4
26-
REMOTE_2_TRAIN_MODEL = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct'
23+
2724

2825
class WeightUpdatedHalfway(Exception):
2926
"""Raised when the remote side starts updating model weights halfway through an episode."""
@@ -46,22 +43,10 @@ def main():
4643
num_repeat=LOCAL_GRPO_N,
4744
),
4845
)
49-
# Hand shake with remote swarm server
50-
swarm_worker_3B = SwarmClient(REMOTE_2_SWARM_URL)
51-
swarm_worker_3B.auto_sync_train_config_and_start_engine(
52-
AgentJetJob(
53-
algorithm="grpo",
54-
project_name="ajet-swarm",
55-
experiment_name="test2",
56-
n_gpu=REMOTE_2_ALLOCATE_GPU_PER_NODE,
57-
model=REMOTE_2_TRAIN_MODEL,
58-
batch_size=REMOTE_BATCH_SIZE,
59-
num_repeat=LOCAL_GRPO_N,
60-
),
61-
)
46+
6247
def play_different_swarm_server(task, swarm_worker:SwarmClient) -> float | None:
6348
# begin episode
64-
episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=120, max_episode_time=240)
49+
episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=120)
6550
# execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
6651
env = FrozenLake(
6752
env_max_steps=20,
@@ -78,10 +63,7 @@ def play_different_swarm_server(task, swarm_worker:SwarmClient) -> float | None:
7863
def rollout(task):
7964
f1 = threading.Thread(target=play_different_swarm_server, args=(task, swarm_worker_7B), daemon=True)
8065
f1.start()
81-
f2 = threading.Thread(target=play_different_swarm_server, args=(task, swarm_worker_3B), daemon=True)
82-
f2.start()
8366
f1.join()
84-
f2.join()
8567
return
8668

8769

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from ajet.copilot.job import AgentJetJob
2+
from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, run_episodes_until_all_complete
3+
from ajet.default_config.ajet_default import AjetTaskReader
4+
from ajet.task_reader import RouterTaskReader
5+
from .frozenlake import FrozenLake
6+
7+
import asyncio
8+
import threading
9+
10+
# step 1: ajet-swarm start --swarm-port=10086
11+
# step 2: ajet-swarm start --swarm-port=10087
12+
# step 3: python -m tutorial.example_frozenlake_swarm.frozen_lake_roll
13+
14+
# ---------------- configuration that takes effect locally ----------------
LOCAL_GRPO_N = 4         # GRPO group size: rollouts generated per task
LOCAL_NUM_EPOCH = 10000  # passes over the task stream

# ---------------- configuration that takes effect remotely ----------------
REMOTE_BATCH_SIZE = 32
# First swarm server (7B model). Change to your swarm remote url.
REMOTE_1_SWARM_URL = "http://localhost:10086"
REMOTE_1_ALLOCATE_GPU_PER_NODE = 4
REMOTE_1_TRAIN_MODEL = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct'
# Second swarm server (3B model). Change to your swarm remote url.
REMOTE_2_SWARM_URL = "http://localhost:10087"
REMOTE_2_ALLOCATE_GPU_PER_NODE = 4
REMOTE_2_TRAIN_MODEL = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct'
26+
27+
class WeightUpdatedHalfway(Exception):
    """Signals that the remote side began updating model weights while an episode was still in flight."""
29+
30+
31+
def main():
    """Drive a FrozenLake rollout loop against two independent swarm servers.

    Two remote trainers (a 7B and a 3B model) are handshaked first; every
    training task is then played against both of them concurrently, one
    thread per server, in batches sized for the remote trainer.
    """
    dataset = RouterTaskReader(reader_type="random_dummy", reader_config=AjetTaskReader())

    def _handshake(url: str, experiment_name: str, n_gpu: int, model: str) -> SwarmClient:
        # Hand shake with a remote swarm server: push the training config
        # and start its engine before any episode begins.
        client = SwarmClient(url)
        client.auto_sync_train_config_and_start_engine(
            AgentJetJob(
                algorithm="grpo",
                project_name="ajet-swarm",
                experiment_name=experiment_name,
                n_gpu=n_gpu,
                model=model,
                batch_size=REMOTE_BATCH_SIZE,
                num_repeat=LOCAL_GRPO_N,
            ),
        )
        return client

    swarm_worker_7B = _handshake(REMOTE_1_SWARM_URL, "test", REMOTE_1_ALLOCATE_GPU_PER_NODE, REMOTE_1_TRAIN_MODEL)
    swarm_worker_3B = _handshake(REMOTE_2_SWARM_URL, "test2", REMOTE_2_ALLOCATE_GPU_PER_NODE, REMOTE_2_TRAIN_MODEL)

    def play_different_swarm_server(task, swarm_worker: SwarmClient) -> float | None:
        """Run a single FrozenLake episode against one swarm server and report it back."""
        # begin episode
        episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=120)
        # execute agent; the agent reaches the remote model through the
        # per-episode endpoint (api_baseurl_key.base_url / .api_key)
        env = FrozenLake(
            env_max_steps=20,
            agent_max_steps=20,
            seed=task.metadata["random_number"],
        )
        workflow_output = asyncio.run(env.execute(task, api_baseurl_key.api_key, api_baseurl_key.base_url))
        # report output back to swarm remote, then print the global
        # rollout status across the swarm
        swarm_worker.end_episode(task, episode_uuid, workflow_output)
        swarm_worker.print_rollout_stat()
        return workflow_output.reward

    def rollout(task):
        """Play the same task against both servers concurrently; block until both finish."""
        threads = []
        for worker in (swarm_worker_7B, swarm_worker_3B):
            t = threading.Thread(target=play_different_swarm_server, args=(task, worker), daemon=True)
            t.start()
            threads.append(t)
        for t in threads:
            t.join()
        return

    next_batch = []
    for epoch in range(LOCAL_NUM_EPOCH):
        for task in dataset.generate_training_tasks():
            # GRPO wants LOCAL_GRPO_N rollouts of each task.
            next_batch.extend(task for _ in range(LOCAL_GRPO_N))
            if len(next_batch) >= (REMOTE_BATCH_SIZE * LOCAL_GRPO_N):
                # A full batch has accumulated: execute it with retry logic.
                run_episodes_until_all_complete(next_batch, func=rollout, auto_retry=True)
                next_batch.clear()
    return None
97+
98+
99+
# Script entry point; main() returns None, so the process exits with status 0.
if __name__ == "__main__":
    raise SystemExit(main())

0 commit comments

Comments
 (0)