Skip to content

Commit 035f153

Browse files
committed
Nits
1 parent be8158c commit 035f153

1 file changed

Lines changed: 17 additions & 25 deletions

File tree

  • src/transformers/generation/continuous_batching

src/transformers/generation/continuous_batching/utils.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -231,47 +231,39 @@ class DistributedHelper:
231231

232232
def __init__(self, device_mesh: DeviceMesh | None) -> None:
233233
self.device_mesh = device_mesh
234-
235-
# Check if distributed is on
236234
self.dist_on = dist.is_available() and dist.is_initialized()
237235

238-
# Get global attributes
239-
if self.dist_on:
240-
self.global_rank = dist.get_rank()
241-
self.world_size = dist.get_world_size()
242-
else:
243-
self.global_rank = 0
244-
self.world_size = 1
245-
246-
# Get TP attributes. If TP is on, the TP setup is stored in the device mesh
247-
if self.dist_on and device_mesh is not None:
248-
self.tp_size = device_mesh.size()
249-
self.tp_group = device_mesh.get_group()
250-
# The src for any TP-scoped collective must be the global rank of TP-rank 0 of THIS rank's TP group.
251-
# In DP x TP, that is not necessarily global rank 0.
236+
# These attributes depend on the global dist state
237+
self.global_rank = dist.get_rank() if self.dist_on else 0
238+
self.world_size = dist.get_world_size() if self.dist_on else 1
239+
240+
# These attributes depend on the TP state
241+
if self.dist_on and self.device_mesh is not None:
242+
self.tp_size = self.device_mesh.size()
243+
self.tp_group = self.device_mesh.get_group()
252244
self.tp_root_global_rank = dist.get_global_rank(self.tp_group, 0)
253-
self.tp_local_rank = device_mesh.get_local_rank()
254-
# Dedicated CPU-only (gloo) process group for requests broadcasts
255-
self.ingress_group = dist.new_group(ranks=dist.get_process_group_ranks(self.tp_group), backend="gloo")
245+
self.tp_local_rank = self.device_mesh.get_local_rank()
246+
self.is_tp_driver = self.tp_local_rank == 0
247+
# If TP is on, we create a dedicate CPU group
248+
tp_ranks = dist.get_process_group_ranks(self.tp_group)
249+
self.ingress_group = dist.new_group(ranks=tp_ranks, backend="gloo")
256250
else:
257251
self.tp_size = 1
258252
self.tp_group = None
259253
self.tp_root_global_rank = 0
260254
self.tp_local_rank = 0
255+
self.is_tp_driver = False
261256
self.ingress_group = None
262257

263-
# The TP-driver is the rank that owns the request queue and the scheduler decisions inside its TP group.
264-
self.is_tp_driver = self.tp_local_rank == 0
265-
266-
# Get DP attributes
258+
# These attributes depend on the DP state
267259
self.dp_rank = self.global_rank // self.tp_size
268260
self.dp_size = self.world_size // self.tp_size
269261

270262
def destroy_ingress_group(self) -> None:
271263
"""Destroys the ingress group."""
272264
if self.ingress_group is not None:
273265
dist.destroy_process_group(self.ingress_group)
274-
self.ingress_group = None
266+
self.ingress_group = None
275267

276268
def tp_broadcast_from_rank_0(self, value: torch.Tensor) -> torch.Tensor:
277269
"""Inside each TP group, broadcasts the given value from rank 0 to all other ranks."""
@@ -291,7 +283,7 @@ def tp_all_reduce_min(self, value: torch.Tensor) -> torch.Tensor:
291283
dist.all_reduce(value, op=dist.ReduceOp.MIN, group=self.tp_group)
292284
return value
293285

294-
def tp_broadcast_object(self, obj: T) -> T:
286+
def tp_broadcast_object(self, obj):
295287
"""Inside each TP group, broadcasts an arbitrary picklable Python object from TP-rank 0 to all other ranks.
296288
Used to keep request ingress and cancellations consistent across TP workers without requiring all ranks to
297289
receive the same external request stream. Uses a dedicated CPU (gloo) `ingress_group` for broadcast."""

0 commit comments

Comments
 (0)