@@ -231,47 +231,39 @@ class DistributedHelper:
231231
232232 def __init__ (self , device_mesh : DeviceMesh | None ) -> None :
233233 self .device_mesh = device_mesh
234-
235- # Check if distributed is on
236234 self .dist_on = dist .is_available () and dist .is_initialized ()
237235
238- # Get global attributes
239- if self .dist_on :
240- self .global_rank = dist .get_rank ()
241- self .world_size = dist .get_world_size ()
242- else :
243- self .global_rank = 0
244- self .world_size = 1
245-
246- # Get TP attributes. If TP is on, the TP setup is stored in the device mesh
247- if self .dist_on and device_mesh is not None :
248- self .tp_size = device_mesh .size ()
249- self .tp_group = device_mesh .get_group ()
250- # The src for any TP-scoped collective must be the global rank of TP-rank 0 of THIS rank's TP group.
251- # In DP x TP, that is not necessarily global rank 0.
236+ # These attributes depend on the global dist state
237+ self .global_rank = dist .get_rank () if self .dist_on else 0
238+ self .world_size = dist .get_world_size () if self .dist_on else 1
239+
240+ # These attributes depend on the TP state
241+ if self .dist_on and self .device_mesh is not None :
242+ self .tp_size = self .device_mesh .size ()
243+ self .tp_group = self .device_mesh .get_group ()
252244 self .tp_root_global_rank = dist .get_global_rank (self .tp_group , 0 )
253- self .tp_local_rank = device_mesh .get_local_rank ()
254- # Dedicated CPU-only (gloo) process group for requests broadcasts
255- self .ingress_group = dist .new_group (ranks = dist .get_process_group_ranks (self .tp_group ), backend = "gloo" )
245+ self .tp_local_rank = self .device_mesh .get_local_rank ()
246+ self .is_tp_driver = self .tp_local_rank == 0
247+ # If TP is on, we create a dedicate CPU group
248+ tp_ranks = dist .get_process_group_ranks (self .tp_group )
249+ self .ingress_group = dist .new_group (ranks = tp_ranks , backend = "gloo" )
256250 else :
257251 self .tp_size = 1
258252 self .tp_group = None
259253 self .tp_root_global_rank = 0
260254 self .tp_local_rank = 0
255+ self .is_tp_driver = False
261256 self .ingress_group = None
262257
263- # The TP-driver is the rank that owns the request queue and the scheduler decisions inside its TP group.
264- self .is_tp_driver = self .tp_local_rank == 0
265-
266- # Get DP attributes
258+ # These attributes depend on the DP state
267259 self .dp_rank = self .global_rank // self .tp_size
268260 self .dp_size = self .world_size // self .tp_size
269261
270262 def destroy_ingress_group (self ) -> None :
271263 """Destroys the ingress group."""
272264 if self .ingress_group is not None :
273265 dist .destroy_process_group (self .ingress_group )
274- self .ingress_group = None
266+ self .ingress_group = None
275267
276268 def tp_broadcast_from_rank_0 (self , value : torch .Tensor ) -> torch .Tensor :
277269 """Inside each TP group, broadcasts the given value from rank 0 to all other ranks."""
@@ -291,7 +283,7 @@ def tp_all_reduce_min(self, value: torch.Tensor) -> torch.Tensor:
291283 dist .all_reduce (value , op = dist .ReduceOp .MIN , group = self .tp_group )
292284 return value
293285
294- def tp_broadcast_object (self , obj : T ) -> T :
286+ def tp_broadcast_object (self , obj ) :
295287 """Inside each TP group, broadcasts an arbitrary picklable Python object from TP-rank 0 to all other ranks.
296288 Used to keep request ingress and cancellations consistent across TP workers without requiring all ranks to
297289 receive the same external request stream. Uses a dedicated CPU (gloo) `ingress_group` for broadcast."""
0 commit comments