@@ -101,6 +101,9 @@ class QueueManager(BaseManager):
101101 self .finish_request_barrier = [
102102 threading .Barrier (self .num_client ) for _ in range (self .local_data_parallel_size )
103103 ]
104+ self .worker_process_tp_barrier = [
105+ threading .Barrier (self .num_client ) for _ in range (self .local_data_parallel_size )
106+ ]
104107
105108 self .finish_add_cache_task_barrier = [
106109 threading .Barrier (self .num_client ) for _ in range (self .local_data_parallel_size )
@@ -193,6 +196,10 @@ class QueueManager(BaseManager):
193196 "get_finish_add_cache_task_barrier" ,
194197 callable = lambda idx : self .finish_add_cache_task_barrier [idx ],
195198 )
199+ QueueManager .register (
200+ "get_worker_process_tp_barrier" ,
201+ callable = lambda idx : self .worker_process_tp_barrier [idx ],
202+ )
196203 self .manager : BaseManager = QueueManager (address = self .address , authkey = self .authkey )
197204 self .manager .start ()
198205 else :
@@ -217,6 +224,7 @@ class QueueManager(BaseManager):
217224 QueueManager .register ("get_connect_rdma_tasks" )
218225 QueueManager .register ("get_connect_rdma_tasks_responses" )
219226 QueueManager .register ("get_connect_task_lock" )
227+ QueueManager .register ("get_worker_process_tp_barrier" )
220228 self .manager = QueueManager (address = self .address , authkey = self .authkey )
221229 self ._connect_with_retry ()
222230
@@ -239,6 +247,7 @@ class QueueManager(BaseManager):
239247 self .finish_add_cache_task_barrier = self .manager .get_finish_add_cache_task_barrier (
240248 self .local_data_parallel_id
241249 )
250+ self .worker_process_tp_barrier = self .manager .get_worker_process_tp_barrier (self .local_data_parallel_id )
242251 self .finished_req_queue = self .manager .get_finish_request_queue (self .local_data_parallel_id )
243252 self .finished_add_cache_task_queue = self .manager .get_finish_add_cache_task_queue (
244253 self .local_data_parallel_id
0 commit comments