@@ -246,24 +246,24 @@ def __init__(self, device_mesh: DeviceMesh | None) -> None:
246246 self .is_tp_driver = self .tp_local_rank == 0
247247 # If TP is on, we create a dedicate CPU group
248248 tp_ranks = dist .get_process_group_ranks (self .tp_group )
249- self .ingress_group = dist .new_group (ranks = tp_ranks , backend = "gloo" )
249+ self .cpu_comm_group = dist .new_group (ranks = tp_ranks , backend = "gloo" )
250250 else :
251251 self .tp_size = 1
252252 self .tp_group = None
253253 self .tp_root_global_rank = 0
254254 self .tp_local_rank = 0
255255 self .is_tp_driver = False
256- self .ingress_group = None
256+ self .cpu_comm_group = None
257257
258258 # These attributes depend on the DP state
259259 self .dp_rank = self .global_rank // self .tp_size
260260 self .dp_size = self .world_size // self .tp_size
261261
262- def destroy_ingress_group (self ) -> None :
263- """Destroys the ingress group."""
264- if self .ingress_group is not None :
265- dist .destroy_process_group (self .ingress_group )
266- self .ingress_group = None
262+ def destroy_cpu_comm_group (self ) -> None :
263+ """Destroys the CPU comm group."""
264+ if self .cpu_comm_group is not None :
265+ dist .destroy_process_group (self .cpu_comm_group )
266+ self .cpu_comm_group = None
267267
268268 def tp_broadcast_from_rank_0 (self , value : torch .Tensor ) -> torch .Tensor :
269269 """Inside each TP group, broadcasts the given value from rank 0 to all other ranks."""
@@ -272,9 +272,9 @@ def tp_broadcast_from_rank_0(self, value: torch.Tensor) -> torch.Tensor:
272272 return value
273273
274274 def tp_broadcast_cpu_from_rank_0 (self , value : torch .Tensor ) -> torch .Tensor :
275- """Inside each TP group, broadcasts a CPU tensor from rank 0 over the gloo ingress group."""
275+ """Inside each TP group, broadcasts a CPU tensor from rank 0 over the gloo CPU comm group."""
276276 if self .tp_size > 1 :
277- dist .broadcast (value , src = self .tp_root_global_rank , async_op = False , group = self .ingress_group )
277+ dist .broadcast (value , src = self .tp_root_global_rank , async_op = False , group = self .cpu_comm_group )
278278 return value
279279
280280 def tp_all_reduce_min (self , value : torch .Tensor ) -> torch .Tensor :
@@ -283,15 +283,15 @@ def tp_all_reduce_min(self, value: torch.Tensor) -> torch.Tensor:
283283 dist .all_reduce (value , op = dist .ReduceOp .MIN , group = self .tp_group )
284284 return value
285285
286- def tp_broadcast_object (self , obj ) :
286+ def tp_broadcast_object (self , obj : T ) -> T :
287287 """Inside each TP group, broadcasts an arbitrary picklable Python object from TP-rank 0 to all other ranks.
288288 Used to keep request ingress and cancellations consistent across TP workers without requiring all ranks to
289- receive the same external request stream. Uses a dedicated CPU (gloo) `ingress_group ` for broadcast."""
289+ receive the same external request stream. Uses a dedicated CPU (gloo) `cpu_comm_group ` for broadcast."""
290290 if self .tp_size <= 1 :
291291 return obj
292292 holder = [obj ] if self .is_tp_driver else [None ]
293293 dist .broadcast_object_list (
294- holder , src = self .tp_root_global_rank , group = self .ingress_group , device = torch .device ("cpu" )
294+ holder , src = self .tp_root_global_rank , group = self .cpu_comm_group , device = torch .device ("cpu" )
295295 )
296296 return holder [0 ]
297297
0 commit comments