@@ -78,12 +78,10 @@ def step_status_printer(self, observation_window):
7878 print_buf += [f"[finished]:{ count } threads" ]
7979 print (f"Rollout progress ({ token_gen_per_sec_str } ): " + " // " .join (print_buf ))
8080
81- if DEBUG :
82- self ._write_swarm_rollout_dynamic_log (observation_window )
83-
8481
8582 def _write_swarm_rollout_dynamic_log (self , observation_window ):
86- fp = "./swarm_rollout.dynamic.log"
83+ base_exp_dir = self .config .ajet .experiment_dir # {exp-dir}/{experiment_name}
84+ fp = f"{ base_exp_dir } /swarm_rollout.dynamic.log"
8785 string_buffer = ""
8886 for info in observation_window ["info" ]:
8987 string_buffer += f"{ info } \n "
@@ -189,25 +187,6 @@ def rollout_swarm( # noqa: C901
189187 completed_task_id_map_ct : Dict [str , List [BaseContextTracker ]] = {}
190188 executor_lock = threading .Lock ()
191189
192- # submit initial tasks
193- dummy_task = Task (main_query = "dummy task" )
194- for task_batch_index in range (n_task ):
195- for task_rollout_index in range (rollout_n ):
196- task_thread_index = task_batch_index * rollout_n + task_rollout_index
197- observation_window ["info" ][task_thread_index ] = f"\n \n \n \n [thread { task_thread_index } submit]\n "
198- future = executor .submit (
199- self .rollout_env_worker_loop ,
200- task = dummy_task ,
201- task_tag = "" ,
202- mode = mode ,
203- task_batch_index = task_batch_index ,
204- task_thread_index = task_thread_index ,
205- observation_window = observation_window ,
206- completed_task_id_map_ct = completed_task_id_map_ct ,
207- executor_lock = executor_lock ,
208- )
209- futures .append (future )
210-
211190 def count_tasks (completed_task_id_map_ct ):
212191 total_completed_episodes = 0
213192 total_completed_tasks = 0
@@ -239,7 +218,7 @@ def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool:
239218 counts = count_tasks (completed_task_id_map_ct )
240219 total_completed_episodes = counts ["total_completed_episodes" ]
241220 total_completed_tasks = counts ["total_completed_tasks" ]
242- if total_completed_episodes > self .config .ajet .swarm_mode_sample_collection_max_cached_episodes // 2 :
221+ if total_completed_episodes > ( self .config .ajet .swarm_mode_sample_collection_max_cached_episodes // 5 * 4 ) :
243222 logger .warning (
244223 f"Total cached episodes [{ total_completed_episodes } ] is going to exceed the max cached episodes [{ self .config .ajet .swarm_mode_sample_collection_max_cached_episodes } ], "
245224 f"but we are still not able to meet the stop condition (current finished tasks [{ total_completed_tasks } ], target tasks [{ n_batch_task } ]), this may cause memory issues. "
@@ -261,7 +240,7 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
261240 total_completed_episodes = counts ["total_completed_episodes" ]
262241 total_completed_tasks = counts ["total_completed_tasks" ]
263242 total_completed_non_dummy_tasks = counts ["total_completed_non_dummy_tasks" ]
264- if total_completed_episodes > self .config .ajet .swarm_mode_sample_collection_max_cached_episodes // 2 :
243+ if total_completed_episodes > ( self .config .ajet .swarm_mode_sample_collection_max_cached_episodes // 5 * 4 ) :
265244 logger .warning (
266245 f"Total cached episodes [{ total_completed_episodes } ] is going to exceed the max cached episodes [{ self .config .ajet .swarm_mode_sample_collection_max_cached_episodes } ], "
267246 f"but we are still not able to meet the stop condition (current finished tasks [{ total_completed_non_dummy_tasks } ], target tasks [{ n_batch_task } ]), this may cause memory issues. "
@@ -291,6 +270,9 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
291270 logger .error (f"Invalid swarm_mode_sample_collection_method: { self .config .ajet .swarm_mode_sample_collection_method } , fallback to default method: rollout_until_finish_enough_tasks" )
292271 stop_condition = enough_finished_task_stop_condition
293272
273+ def is_already_soft_stopped ():
274+ return all (observation_window ["stop" ])
275+
294276 # communicate with interchange server to stop new episode, and let threads finish current episode, then collect results and shutdown executor
295277 def stop_all_threads_soft ():
296278 for k in range (len (observation_window ["stop" ])): observation_window ["stop" ][k ] = True
@@ -303,6 +285,34 @@ def stop_all_threads_hard():
303285 http_change_engine_status (self .config , "ENGINE.WEIGHT_SYNCING" )
304286 return
305287
288+ # pass a stop condition callback function to each thread, so that threads can check the stop condition whenever it finishes a cycle, this is faster than polling
289+ def stop_condition_callback (completed_task_id_map_ct ):
290+ if stop_condition (completed_task_id_map_ct ):
291+ if not is_already_soft_stopped ():
292+ stop_all_threads_soft ()
293+ return True
294+ return False
295+
296+ # submit initial tasks
297+ dummy_task = Task (main_query = "dummy task" )
298+ for task_batch_index in range (n_task ):
299+ for task_rollout_index in range (rollout_n ):
300+ task_thread_index = task_batch_index * rollout_n + task_rollout_index
301+ observation_window ["info" ][task_thread_index ] = f"\n \n \n \n [thread { task_thread_index } submit]\n "
302+ future = executor .submit (
303+ self .rollout_env_worker_loop ,
304+ task = dummy_task ,
305+ task_tag = "" ,
306+ mode = mode ,
307+ task_batch_index = task_batch_index ,
308+ task_thread_index = task_thread_index ,
309+ observation_window = observation_window ,
310+ completed_task_id_map_ct = completed_task_id_map_ct ,
311+ executor_lock = executor_lock ,
312+ stop_condition_callback = stop_condition_callback ,
313+ )
314+ futures .append (future )
315+
306316 def update_rollout_result_array_preview (observation_window , completed_task_id_map_ct : Dict [str , List [BaseContextTracker ]]):
307317 buffer = ""
308318 completed_tasks_details = {}
@@ -334,17 +344,19 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
334344 completed_tasks_details = completed_tasks_details ,
335345 )
336346 http_update_rollout_pool_information (self .config , pool_info )
337-
338347 return
339348
340349 # loop and wait until stop condition is met, then stop threads and collect results
350+ CHECK_STATUS_INTERVAL = 4 # seconds
351+ PRINT_STATUS_INTERVAL = 12 # seconds
341352 cnt = 0
342353 while True :
343354 cnt += 1
344- time .sleep (2 )
345- if (cnt % 5 == 0 ):
355+ time .sleep (CHECK_STATUS_INTERVAL )
356+ if (cnt % ( PRINT_STATUS_INTERVAL // CHECK_STATUS_INTERVAL ) == 0 ):
346357 update_rollout_result_array_preview (observation_window , completed_task_id_map_ct )
347358 self .step_status_printer (observation_window )
359+ self ._write_swarm_rollout_dynamic_log (observation_window )
348360 meet_stop_condition_after_new_results = stop_condition (completed_task_id_map_ct )
349361 if meet_stop_condition_after_new_results :
350362 print ("Sending soft stop signal to all threads..." )
@@ -377,9 +389,8 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
377389 tracker .current_batch_reward = float (task_scalar_reward )
378390
379391 # for debugging
380- if DEBUG :
381- update_rollout_result_array_preview (observation_window , completed_task_id_map_ct )
382- self ._write_swarm_rollout_dynamic_log (observation_window )
392+ update_rollout_result_array_preview (observation_window , completed_task_id_map_ct )
393+ self ._write_swarm_rollout_dynamic_log (observation_window )
383394
384395 return tracker_array
385396
@@ -428,9 +439,9 @@ def trajectories_to_samples(self, tracker_array: List[BaseContextTracker]) -> Li
428439 try :
429440 sample_arr = tracker .group_tokenize ()
430441 except Exception as e :
442+ logger .bind (exception = True ).exception ("Error during tracker.group_tokenize()" )
431443 raise e
432444 finally :
433- logger .bind (exception = True ).exception ("Error during tracker.tokenize()" ) # for debugging
434445 tracker .generate_log (global_step = self .current_global_steps )
435446 if os .environ .get ("BEST_LOGGER_PATH" , None ) and os .environ .get (
436447 "AJET_DEBUG" , None
0 commit comments