88
99import numpy as np
1010import torch
11+ import threading
1112from loguru import logger
1213from tensordict import TensorDict
1314from torch .nn .utils .rnn import pad_sequence
1415from tqdm import tqdm
1516from verl import DataProto
1617from verl .utils .torch_functional import pad_sequence_to_length
1718
18- from ajet .context_tracker .basic_tracker import BaseContextTracker
1919from ajet .schema .task import Task
2020from ajet .schema .trajectory import Sample
2121from ajet .task_rollout .single_worker import BaseRolloutManager
22+ from ajet .context_tracker .basic_tracker import BaseContextTracker
23+ from ajet .tuner_lib .weight_tuner .experimental .interchange_utils import http_change_engine_status
2224
2325
2426class DynamicRolloutManager (BaseRolloutManager ):
@@ -481,33 +483,39 @@ def rollout_swarm( # noqa: C901
481483 tracker_array : List [BaseContextTracker ] = []
482484 assert mode != "validate"
483485 rollout_n = self .rollout_n
484- n_task = len (tasks )
486+ n_batch_task = len (tasks )
487+ n_task = min (len (tasks ), self .max_parallel // rollout_n )
488+ assert n_task > 0 , f"n_task is not valid, n_task = min(len(tasks), self.max_parallel // rollout_n) = { n_task } "
485489 self .current_token_count_time = time .time ()
486490
487491 # initialize observation window
488492 observation_window : Dict [str , List [int | bool | str ]] = {
489493 "info" : ["" for _ in range (n_task * rollout_n )],
490494 "step" : [0 for _ in range (n_task * rollout_n )],
491495 "stop" : [False for _ in range (n_task * rollout_n )],
496+ "hard_stop" : [False for _ in range (n_task * rollout_n )],
492497 "token" : [0 for _ in range (n_task * rollout_n )],
493498 }
494499 executor = ThreadPoolExecutor (max_workers = self .max_parallel )
495500 futures : List [Future ] = []
496501 completed_task_id_map_ct : Dict [str , List [BaseContextTracker ]] = {}
502+ executor_lock = threading .Lock ()
497503
498504 # submit initial tasks
499505 dummy_task = Task (main_query = "dummy task" )
500506 for task_batch_index in range (n_task ):
501507 for task_rollout_index in range (rollout_n ):
502508 task_thread_index = task_batch_index * rollout_n + task_rollout_index
503509 future = executor .submit (
504- self .rollout_env_worker ,
510+ self .rollout_env_worker_loop ,
505511 task = dummy_task ,
506512 task_tag = "" ,
507513 mode = mode ,
508514 task_batch_index = task_batch_index ,
509515 task_thread_index = task_thread_index ,
510516 observation_window = observation_window ,
517+ completed_task_id_map_ct = completed_task_id_map_ct ,
518+ executor_lock = executor_lock ,
511519 )
512520 observation_window ["info" ][task_thread_index ] = "1"
513521 futures .append (future )
@@ -516,14 +524,15 @@ def enough_sample_stop_condition(completed_task_id_map_ct) -> bool:
516524 n = 0
517525 for ct_list in completed_task_id_map_ct .values ():
518526 n += len (ct_list )
519- return (n >= n_task * rollout_n )
527+ print (f"Current collected samples: { n } , target: { n_batch_task * rollout_n } " )
528+ return (n >= n_batch_task * rollout_n )
520529
521530 def enough_finished_task_stop_condition (completed_task_id_map_ct ) -> bool :
522531 n_finish_roll_task = 0
523532 for ct_list in completed_task_id_map_ct .values ():
524533 if len (ct_list ) >= rollout_n :
525534 n_finish_roll_task += 1
526- return (n_finish_roll_task >= n_task )
535+ return (n_finish_roll_task >= n_batch_task )
527536
528537 def enough_non_dummy_task_stop_condition (completed_task_id_map_ct ) -> bool :
529538 n_finish_roll_task = 0
@@ -535,63 +544,39 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
535544 all_equal = all (x == task_cmd_reward_array [0 ] for x in task_cmd_reward_array )
536545 if all_equal : continue
537546 n_finish_roll_task += 1
538- return (n_finish_roll_task >= n_task )
547+ return (n_finish_roll_task >= n_batch_task )
539548
540549 stop_condition = enough_sample_stop_condition
541550
542- def force_stop_all_threads ():
543- for k in range (len (observation_window ["stop" ])):
544- observation_window ["stop" ][k ] = True
551+ def stop_all_threads_soft ():
552+ for k in range (len (observation_window ["stop" ])): observation_window ["stop" ][k ] = True
553+ http_change_engine_status (self .config , "ENGINE.ROLLING_POST" )
554+ return
555+
556+ def stop_all_threads_hard ():
557+ for k in range (len (observation_window ["hard_stop" ])): observation_window ["hard_stop" ][k ] = True
558+ http_change_engine_status (self .config , "ENGINE.WEIGHT_SYNCING" )
545559 return
546560
547- tic = time . time ()
561+ cnt = 0
548562 while True :
549- # wait for a completed task
550- done_arr , pending_arr = wait (futures , timeout = 10 , return_when = FIRST_COMPLETED )
551- print (f"Done tasks: { len (done_arr )} , Pending tasks: { len (pending_arr )} " )
552- toc = time .time ()
553- if (toc - tic ) > 8 :
554- tic = toc
563+ cnt += 1
564+ time .sleep (2 )
565+ if (cnt % 5 == 0 ):
555566 self .step_status_printer (observation_window )
556- # get result
557- for future in done_arr :
558- ct : BaseContextTracker = future .result ()
559- if ct .task_id not in completed_task_id_map_ct :
560- completed_task_id_map_ct [ct .task_id ] = [ct ]
561- else :
562- completed_task_id_map_ct [ct .task_id ] += [ct ]
563- # if meet stop condition
564567 meet_stop_condition_after_new_results = stop_condition (completed_task_id_map_ct )
565568 if meet_stop_condition_after_new_results :
566- force_stop_all_threads ()
569+ print ("Sending soft stop signal to all threads..." )
570+ stop_all_threads_soft ()
567571 break
568- else :
569- # re-spawn new tasks for done futures
570- for task_batch_index in range (n_task ):
571- for task_rollout_index in range (rollout_n ):
572- task_thread_index = task_batch_index * rollout_n + task_rollout_index
573- has_done = (futures [task_thread_index ] in done_arr )
574-
575- observation_window ["info" ][task_thread_index ] = str (int (observation_window ["info" ][task_thread_index ]) + 1 )
576- observation_window ["stop" ][task_thread_index ] = False
577- observation_window ["step" ][task_thread_index ] = 0
578-
579- if has_done :
580- print (f"Re-spawning thread { task_thread_index } ..." )
581- future = executor .submit (
582- self .rollout_env_worker ,
583- task = dummy_task ,
584- task_tag = "" ,
585- mode = mode ,
586- task_batch_index = task_batch_index ,
587- task_thread_index = task_thread_index ,
588- observation_window = observation_window ,
589- )
590- futures [task_thread_index ] = future
591572
592573 # wait for all threads to complete
593574 print ('Finalizing all threads...' )
594- wait (futures , return_when = ALL_COMPLETED )
575+ executor .shutdown (wait = True )
576+
577+ # stop all threads hard
578+ print ("Sending hard stop signal to all threads..." )
579+ stop_all_threads_hard ()
595580
596581 # build tracker_array
597582 print ('Collecting results...' )
0 commit comments