Skip to content

Commit 812d96c

Browse files
committed
add get_global_step swarm api
1 parent 5f25de0 commit 812d96c

3 files changed

Lines changed: 26 additions & 14 deletions

File tree

ajet/tuner_lib/experimental/swarm_client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ def get_engine_status(self) -> Tuple[str, dict]:
166166
self._engine_status_ready.wait(timeout=15)
167167
return self._engine_status_cache or ("ENGINE.CANNOT_CONNECT", {})
168168

169+
def get_global_step(self) -> int:
170+
"""Return the current global training step from the swarm server."""
171+
_, status_json = self.get_engine_status()
172+
return status_json.get("global_step", 0)
173+
169174
def _engine_status_poll_loop(self):
170175
"""Background thread: fetch engine status at _engine_status_poll_interval."""
171176
while not self._engine_status_poll_stop.is_set():

tutorial/opencode_build_aime/auto_research/auto_train.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def __init__(
102102
data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
103103
self.train_dataset = os.path.join(data_dir, "dapo-math-17k.parquet")
104104
self.test_datasets = {
105-
"AIME-2024": os.path.join(data_dir, "aime-2024.parquet"),
105+
# "AIME-2024": os.path.join(data_dir, "aime-2024.parquet"),
106106
"AIME-2025": os.path.join(data_dir, "aime-2025.parquet"),
107107
"AIME-2026": os.path.join(data_dir, "aime-2026.parquet"),
108108
"DAPO-Math-Tiny-Val": os.path.join(data_dir, "dapo-math-tiny-val.parquet"),
@@ -292,7 +292,6 @@ def train(self):
292292
assert self.swarm_worker is not None and self.dataset is not None, "setup() must be called before train()"
293293
self.run_eval(0)
294294

295-
task_count = 0
296295
max_parallel = 64
297296
executor = TaskCountLimitedThreadPoolExecutor(
298297
max_parallel_groups=self.batch_size,
@@ -302,18 +301,18 @@ def train(self):
302301
self.swarm_worker.add_entering_weight_sync_callback(executor.on_entering_weight_sync)
303302

304303
num_epochs = 10000
305-
n_global_step = 0
304+
last_eval_step = 0
306305
for epoch in range(num_epochs):
307306
for _, task in enumerate(self.dataset.generate_training_tasks()):
308307
args_list = [{"task": task} for _ in range(self.grpo_n)]
309308
executor.submit_group(task_id=task.task_id, fn=self.rollout, args_list=args_list)
310309

311-
task_count += 1
310+
n_global_step = self.swarm_worker.get_global_step()
312311

313-
time_to_eval = task_count % (self.eval_interval * self.batch_size) == 0
314-
n_global_step = task_count // self.batch_size
312+
time_to_eval = n_global_step >= last_eval_step + self.eval_interval
315313
if time_to_eval:
316314
self.run_eval(n_global_step)
315+
last_eval_step = n_global_step
317316

318317
if n_global_step >= self.total_training_steps:
319318
break

tutorial/opencode_build_aime/auto_research/auto_train_kl.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
swarm_url: str,
7575
project_name: str = DEFAULT_PROJECT_NAME,
7676
resolved_yaml_path: str | None = None,
77+
prepare_only: bool = False,
7778
max_prompt_length: int = 3000,
7879
max_response_length: int = 15000,
7980
max_model_len: int = 18000,
@@ -96,6 +97,7 @@ def __init__(
9697
self.result_dir = result_dir
9798
self.project_name = project_name
9899
self.resolved_yaml_path = resolved_yaml_path or os.path.join(result_dir, "resolved_swarm_config.yaml")
100+
self.prepare_only = prepare_only
99101
self.max_prompt_length = max_prompt_length
100102
self.max_response_length = max_response_length
101103
self.max_model_len = max_model_len
@@ -105,7 +107,7 @@ def __init__(
105107
data_dir = os.path.join(os.path.dirname(__file__), "..", "data")
106108
self.train_dataset = os.path.join(data_dir, "dapo-math-17k.parquet")
107109
self.test_datasets = {
108-
"AIME-2024": os.path.join(data_dir, "aime-2024.parquet"),
110+
# "AIME-2024": os.path.join(data_dir, "aime-2024.parquet"),
109111
"AIME-2025": os.path.join(data_dir, "aime-2025.parquet"),
110112
"AIME-2026": os.path.join(data_dir, "aime-2026.parquet"),
111113
"DAPO-Math-Tiny-Val": os.path.join(data_dir, "dapo-math-tiny-val.parquet"),
@@ -177,6 +179,9 @@ def setup(self):
177179

178180
self.ajet_job.dump_job_as_yaml(self.resolved_yaml_path)
179181

182+
if self.prepare_only:
183+
return
184+
180185
self.dataset = RouterTaskReader(
181186
reader_type="huggingface_dat_repo",
182187
reader_config=AjetTaskReader(
@@ -191,7 +196,7 @@ def setup(self):
191196
)
192197

193198
eval_downloaders = {
194-
"AIME-2024": download_data.ensure_aime_2024,
199+
# "AIME-2024": download_data.ensure_aime_2024,
195200
"AIME-2025": download_data.ensure_aime_2025,
196201
"AIME-2026": download_data.ensure_aime_2026,
197202
}
@@ -296,9 +301,10 @@ def _run_eval_one(self, n_global_step: int, label: str, eval_tasks: list, eval_l
296301

297302
def train(self):
298303
assert self.swarm_worker is not None and self.dataset is not None, "setup() must be called before train()"
304+
305+
last_eval_step = 0
299306
self.run_eval(0)
300307

301-
task_count = 0
302308
max_parallel = 64
303309
executor = TaskCountLimitedThreadPoolExecutor(
304310
max_parallel_groups=self.batch_size,
@@ -308,18 +314,17 @@ def train(self):
308314
self.swarm_worker.add_entering_weight_sync_callback(executor.on_entering_weight_sync)
309315

310316
num_epochs = 10000
311-
n_global_step = 0
312317
for epoch in range(num_epochs):
313318
for _, task in enumerate(self.dataset.generate_training_tasks()):
314319
args_list = [{"task": task} for _ in range(self.grpo_n)]
315320
executor.submit_group(task_id=task.task_id, fn=self.rollout, args_list=args_list)
316321

317-
task_count += 1
322+
n_global_step = self.swarm_worker.get_global_step()
318323

319-
time_to_eval = task_count % (self.eval_interval * self.batch_size) == 0
320-
n_global_step = task_count // self.batch_size
324+
time_to_eval = n_global_step >= last_eval_step + self.eval_interval
321325
if time_to_eval:
322326
self.run_eval(n_global_step)
327+
last_eval_step = n_global_step
323328

324329
if n_global_step >= self.total_training_steps:
325330
break
@@ -335,6 +340,8 @@ def train(self):
335340

336341
def run(self):
337342
self.setup()
343+
if self.prepare_only:
344+
return
338345
self.train()
339346

340347

@@ -371,7 +378,7 @@ def main():
371378
help="Evaluate every N global steps")
372379
parser.add_argument("--eval-k", type=int, default=4,
373380
help="Number of rollouts per eval task (pass@k)")
374-
parser.add_argument("--grpo-repeat", type=int, default=8,
381+
parser.add_argument("--grpo-repeat", type=int, default=4,
375382
help="GRPO num_repeat per training task")
376383
parser.add_argument("--ppo-epochs", type=int, default=1,
377384
help="Number of PPO epochs per update")
@@ -397,6 +404,7 @@ def main():
397404
swarm_url=args.swarm_url,
398405
project_name=args.project_name,
399406
resolved_yaml_path=args.resolved_yaml_path,
407+
prepare_only=args.prepare_only,
400408
max_prompt_length=args.max_prompt_length,
401409
max_response_length=args.max_response_length,
402410
max_model_len=args.max_model_len,

0 commit comments

Comments
 (0)