@@ -74,6 +74,7 @@ def __init__(
7474 swarm_url : str ,
7575 project_name : str = DEFAULT_PROJECT_NAME ,
7676 resolved_yaml_path : str | None = None ,
77+ prepare_only : bool = False ,
7778 max_prompt_length : int = 3000 ,
7879 max_response_length : int = 15000 ,
7980 max_model_len : int = 18000 ,
@@ -96,6 +97,7 @@ def __init__(
9697 self .result_dir = result_dir
9798 self .project_name = project_name
9899 self .resolved_yaml_path = resolved_yaml_path or os .path .join (result_dir , "resolved_swarm_config.yaml" )
100+ self .prepare_only = prepare_only
99101 self .max_prompt_length = max_prompt_length
100102 self .max_response_length = max_response_length
101103 self .max_model_len = max_model_len
@@ -105,7 +107,7 @@ def __init__(
105107 data_dir = os .path .join (os .path .dirname (__file__ ), ".." , "data" )
106108 self .train_dataset = os .path .join (data_dir , "dapo-math-17k.parquet" )
107109 self .test_datasets = {
108- "AIME-2024" : os .path .join (data_dir , "aime-2024.parquet" ),
110+ # "AIME-2024": os.path.join(data_dir, "aime-2024.parquet"),
109111 "AIME-2025" : os .path .join (data_dir , "aime-2025.parquet" ),
110112 "AIME-2026" : os .path .join (data_dir , "aime-2026.parquet" ),
111113 "DAPO-Math-Tiny-Val" : os .path .join (data_dir , "dapo-math-tiny-val.parquet" ),
@@ -177,6 +179,9 @@ def setup(self):
177179
178180 self .ajet_job .dump_job_as_yaml (self .resolved_yaml_path )
179181
182+ if self .prepare_only :
183+ return
184+
180185 self .dataset = RouterTaskReader (
181186 reader_type = "huggingface_dat_repo" ,
182187 reader_config = AjetTaskReader (
@@ -191,7 +196,7 @@ def setup(self):
191196 )
192197
193198 eval_downloaders = {
194- "AIME-2024" : download_data .ensure_aime_2024 ,
199+ # "AIME-2024": download_data.ensure_aime_2024,
195200 "AIME-2025" : download_data .ensure_aime_2025 ,
196201 "AIME-2026" : download_data .ensure_aime_2026 ,
197202 }
@@ -296,9 +301,10 @@ def _run_eval_one(self, n_global_step: int, label: str, eval_tasks: list, eval_l
296301
297302 def train (self ):
298303 assert self .swarm_worker is not None and self .dataset is not None , "setup() must be called before train()"
304+
305+ last_eval_step = 0
299306 self .run_eval (0 )
300307
301- task_count = 0
302308 max_parallel = 64
303309 executor = TaskCountLimitedThreadPoolExecutor (
304310 max_parallel_groups = self .batch_size ,
@@ -308,18 +314,17 @@ def train(self):
308314 self .swarm_worker .add_entering_weight_sync_callback (executor .on_entering_weight_sync )
309315
310316 num_epochs = 10000
311- n_global_step = 0
312317 for epoch in range (num_epochs ):
313318 for _ , task in enumerate (self .dataset .generate_training_tasks ()):
314319 args_list = [{"task" : task } for _ in range (self .grpo_n )]
315320 executor .submit_group (task_id = task .task_id , fn = self .rollout , args_list = args_list )
316321
317- task_count += 1
322+ n_global_step = self . swarm_worker . get_global_step ()
318323
319- time_to_eval = task_count % (self .eval_interval * self .batch_size ) == 0
320- n_global_step = task_count // self .batch_size
324+ time_to_eval = n_global_step >= last_eval_step + self .eval_interval
321325 if time_to_eval :
322326 self .run_eval (n_global_step )
327+ last_eval_step = n_global_step
323328
324329 if n_global_step >= self .total_training_steps :
325330 break
@@ -335,6 +340,8 @@ def train(self):
335340
def run(self):
    """Run the full pipeline: ``setup()`` first, then ``train()``.

    When ``self.prepare_only`` is truthy, only ``setup()`` is executed and
    training is skipped entirely.
    """
    self.setup()
    # Preparation-only mode stops here; otherwise continue into training.
    if not self.prepare_only:
        self.train()
339346
340347
@@ -371,7 +378,7 @@ def main():
371378 help = "Evaluate every N global steps" )
372379 parser .add_argument ("--eval-k" , type = int , default = 4 ,
373380 help = "Number of rollouts per eval task (pass@k)" )
374- parser .add_argument ("--grpo-repeat" , type = int , default = 8 ,
381+ parser .add_argument ("--grpo-repeat" , type = int , default = 4 ,
375382 help = "GRPO num_repeat per training task" )
376383 parser .add_argument ("--ppo-epochs" , type = int , default = 1 ,
377384 help = "Number of PPO epochs per update" )
@@ -397,6 +404,7 @@ def main():
397404 swarm_url = args .swarm_url ,
398405 project_name = args .project_name ,
399406 resolved_yaml_path = args .resolved_yaml_path ,
407+ prepare_only = args .prepare_only ,
400408 max_prompt_length = args .max_prompt_length ,
401409 max_response_length = args .max_response_length ,
402410 max_model_len = args .max_model_len ,
0 commit comments