Skip to content

Commit eee5471

Browse files
committed
Renames
1 parent 035f153 commit eee5471

3 files changed

Lines changed: 14 additions & 14 deletions

File tree

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ def stop(self, block: bool = True, timeout: float | None = None, keep_for_next_s
601601
# If the manager is not being kept for next session, we clear the batch processor
602602
if not keep_for_next_session:
603603
self.batch_processor = None
604-
self.distributed_helper.destroy_ingress_group()
604+
self.distributed_helper.destroy_cpu_comm_group()
605605
# Otherwise, we keep the batch processor and cache the manager as a model attribute
606606
else:
607607
logger.info("Continuous batching manager will be kept for next session.")

src/transformers/generation/continuous_batching/utils.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -246,24 +246,24 @@ def __init__(self, device_mesh: DeviceMesh | None) -> None:
246246
self.is_tp_driver = self.tp_local_rank == 0
247247
# If TP is on, we create a dedicated CPU group
248248
tp_ranks = dist.get_process_group_ranks(self.tp_group)
249-
self.ingress_group = dist.new_group(ranks=tp_ranks, backend="gloo")
249+
self.cpu_comm_group = dist.new_group(ranks=tp_ranks, backend="gloo")
250250
else:
251251
self.tp_size = 1
252252
self.tp_group = None
253253
self.tp_root_global_rank = 0
254254
self.tp_local_rank = 0
255255
self.is_tp_driver = False
256-
self.ingress_group = None
256+
self.cpu_comm_group = None
257257

258258
# These attributes depend on the DP state
259259
self.dp_rank = self.global_rank // self.tp_size
260260
self.dp_size = self.world_size // self.tp_size
261261

262-
def destroy_ingress_group(self) -> None:
263-
"""Destroys the ingress group."""
264-
if self.ingress_group is not None:
265-
dist.destroy_process_group(self.ingress_group)
266-
self.ingress_group = None
262+
def destroy_cpu_comm_group(self) -> None:
263+
"""Destroys the CPU comm group."""
264+
if self.cpu_comm_group is not None:
265+
dist.destroy_process_group(self.cpu_comm_group)
266+
self.cpu_comm_group = None
267267

268268
def tp_broadcast_from_rank_0(self, value: torch.Tensor) -> torch.Tensor:
269269
"""Inside each TP group, broadcasts the given value from rank 0 to all other ranks."""
@@ -272,9 +272,9 @@ def tp_broadcast_from_rank_0(self, value: torch.Tensor) -> torch.Tensor:
272272
return value
273273

274274
def tp_broadcast_cpu_from_rank_0(self, value: torch.Tensor) -> torch.Tensor:
275-
"""Inside each TP group, broadcasts a CPU tensor from rank 0 over the gloo ingress group."""
275+
"""Inside each TP group, broadcasts a CPU tensor from rank 0 over the gloo CPU comm group."""
276276
if self.tp_size > 1:
277-
dist.broadcast(value, src=self.tp_root_global_rank, async_op=False, group=self.ingress_group)
277+
dist.broadcast(value, src=self.tp_root_global_rank, async_op=False, group=self.cpu_comm_group)
278278
return value
279279

280280
def tp_all_reduce_min(self, value: torch.Tensor) -> torch.Tensor:
@@ -283,15 +283,15 @@ def tp_all_reduce_min(self, value: torch.Tensor) -> torch.Tensor:
283283
dist.all_reduce(value, op=dist.ReduceOp.MIN, group=self.tp_group)
284284
return value
285285

286-
def tp_broadcast_object(self, obj):
286+
def tp_broadcast_object(self, obj: T) -> T:
287287
"""Inside each TP group, broadcasts an arbitrary picklable Python object from TP-rank 0 to all other ranks.
288288
Used to keep request ingress and cancellations consistent across TP workers without requiring all ranks to
289-
receive the same external request stream. Uses a dedicated CPU (gloo) `ingress_group` for broadcast."""
289+
receive the same external request stream. Uses a dedicated CPU (gloo) `cpu_comm_group` for broadcast."""
290290
if self.tp_size <= 1:
291291
return obj
292292
holder = [obj] if self.is_tp_driver else [None]
293293
dist.broadcast_object_list(
294-
holder, src=self.tp_root_global_rank, group=self.ingress_group, device=torch.device("cpu")
294+
holder, src=self.tp_root_global_rank, group=self.cpu_comm_group, device=torch.device("cpu")
295295
)
296296
return holder[0]
297297

tests/generation/test_continuous_batching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ def test_distributed_helper_no_dist(self) -> None:
504504
self.assertEqual(helper.dp_size, 1)
505505
self.assertTrue(helper.is_tp_driver)
506506
self.assertIsNone(helper.tp_group)
507-
self.assertIsNone(helper.ingress_group)
507+
self.assertIsNone(helper.cpu_comm_group)
508508

509509
# Tensor and object broadcasts should be no-ops without a TP group
510510
tensor = torch.tensor([1.0, 2.0])

0 commit comments

Comments
 (0)