Skip to content

Commit be8158c

Browse files
committed
CC review
1 parent e04ce5e commit be8158c

6 files changed

Lines changed: 41 additions & 22 deletions

File tree

benchmark_v2/benchmark_scripts/continuous_batching_overall.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def score(outputs) -> float:
102102

103103

104104
# Data helpers
105-
def get_tokenized_gms8k(
105+
def get_tokenized_gsm8k(
106106
tokenizer: AutoTokenizer, n_fewshot: int = 8
107107
) -> tuple[list[list[int]], Callable[[Any], float]]:
108108
"""GSM8K-Platinum few-shot inputs and scorer using the same lighteval extractive_match as the gsm8k task."""
@@ -323,7 +323,7 @@ def diff(cur: float | None, base: float | None) -> str:
323323

324324
# GSM8K benchmarks (256 max new tokens) — gsm8k_platinum dataset, 8-shot, lighteval extractive_match
325325
tokenizer = AutoTokenizer.from_pretrained(cli_args.model_id, padding_side="left")
326-
gsm8k_data, gsm8k_score_fn = get_tokenized_gms8k(tokenizer)
326+
gsm8k_data, gsm8k_score_fn = get_tokenized_gsm8k(tokenizer)
327327

328328
## No options
329329
results.add_benchmark(

src/transformers/generation/configuration_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1748,7 +1748,8 @@ class ContinuousBatchingConfig:
17481748
disable_nccl_graph_mixing: bool = True
17491749

17501750
def __post_init__(self):
1751-
if self.disable_nccl_graph_mixing:
1751+
# Only turn off graph mixing support if TP is on
1752+
if self.disable_nccl_graph_mixing and int(os.environ.get("WORLD_SIZE", "1")) > 1:
17521753
os.environ.setdefault("NCCL_GRAPH_MIXING_SUPPORT", "0")
17531754

17541755
def account_for_cb_deprecated_arguments(

src/transformers/generation/continuous_batching/cache_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def compute_hash(self, parent_hash: int | None, tokens: list[int], group_id: int
279279
"""Computes the hash of a block identified by the (tokens) it contains, its (parent_hash) and the layer
280280
(group_id) it belongs to. If the block has no parent, the parent hash is None. Uses blake2b for a deterministic
281281
64-bit digest that is stable across processes (unlike Python's salted built-in `hash`)."""
282+
# NOTE: blake2b is ~10–20× slower than hash() here; consider gating by tp_size>1 or switching to xxhash.
282283
h = hashlib.blake2b(digest_size=8)
283284
if parent_hash is not None:
284285
h.update(parent_hash.to_bytes(8, "little", signed=False))

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -260,18 +260,19 @@ def _get_new_requests(self) -> None:
260260
"""Pull new requests and cancellations from the queues and apply them to the scheduler. If the process is a TP
261261
driver, the input_queue and cancel_queue are not None and the process will drain them. Otherwise, the process
262262
will wait for the TP driver to send a payload containing the new requests and cancellations."""
263-
# Only drains queues if this process is a TP driver
263+
# On the TP driver, drain the queues; non-driver ranks start from an empty tuple that gets overwritten by the
264+
# broadcast below.
265+
payload: tuple[list[RequestState], list[str]] = ([], [])
264266
if self.input_queue is not None and self.cancel_queue is not None:
265-
new_states = drain_queue(self.input_queue)
266-
cancellations = drain_queue(self.cancel_queue)
267-
payload = (new_states, cancellations)
268-
# Otherwise, the payload is None
269-
else:
270-
payload = ([], [])
267+
payload = (drain_queue(self.input_queue), drain_queue(self.cancel_queue))
268+
269+
# Cheap CPU/gloo presence check: skip the (pickled) object broadcast entirely when there is nothing to send.
270+
presence = torch.tensor([len(payload[0]) + len(payload[1])], dtype=torch.int64)
271+
self.distributed_helper.tp_broadcast_cpu_from_rank_0(presence)
272+
if presence.item() == 0:
273+
return
271274

272-
# Broadcast within the TP group. No-op when tp_size == 1, returns the driver's payload unchanged.
273-
payload = self.distributed_helper.tp_broadcast_object(payload)
274-
new_states, cancellations = payload
275+
new_states, cancellations = self.distributed_helper.tp_broadcast_object(payload)
275276

276277
# All ranks apply the same updates in the same order.
277278
for state in new_states:

src/transformers/generation/continuous_batching/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,12 @@ def tp_broadcast_from_rank_0(self, value: torch.Tensor) -> torch.Tensor:
279279
dist.broadcast(value, src=self.tp_root_global_rank, async_op=False, group=self.tp_group)
280280
return value
281281

282+
def tp_broadcast_cpu_from_rank_0(self, value: torch.Tensor) -> torch.Tensor:
283+
"""Inside each TP group, broadcasts a CPU tensor from rank 0 over the gloo ingress group."""
284+
if self.tp_size > 1:
285+
dist.broadcast(value, src=self.tp_root_global_rank, async_op=False, group=self.ingress_group)
286+
return value
287+
282288
def tp_all_reduce_min(self, value: torch.Tensor) -> torch.Tensor:
283289
"""Inside each TP group, all-reduces a tensor with the MIN op. No-op when TP is off."""
284290
if self.tp_size > 1:
@@ -300,7 +306,7 @@ def tp_broadcast_object(self, obj: T) -> T:
300306
def maybe_warn_nccl_graph_mixing(self) -> None:
301307
"""Throws a warning if TP is on and NCCL's graph mixing support was supposed to be disabled but isn't. That can
302308
happen if the distributed group is created before graph mixing is disabled. Typically, if the model is
303-
initialized before the ContinousBatchingConfig is created."""
309+
initialized before the ContinuousBatchingConfig is created."""
304310
tp_on = self.tp_size > 1
305311
graph_mixing_not_disabled = os.environ.get("NCCL_GRAPH_MIXING_SUPPORT") != "0"
306312
if tp_on and graph_mixing_not_disabled:

tests/generation/test_continuous_batching.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def test_continuous_batching_will_allocation_be_successful(
356356
num_free_blocks: int,
357357
expected_result: bool,
358358
) -> None:
359-
"""Test the will_allocation_be_successful method of PagedAttentionCache, overloading the elevant attributes of
359+
"""Test the will_allocation_be_successful method of PagedAttentionCache, overloading the relevant attributes of
360360
a dummy cache."""
361361

362362
if torch_device is None: # this check should always pass and helps with type checking
@@ -532,15 +532,21 @@ def test_distributed_helper_set_tp_seed_no_dist(self) -> None:
532532
helper.set_tp_seed(seed=None, model_device=torch.device("cpu"))
533533

534534
def test_continuous_batching_config_disables_nccl_graph_mixing(self) -> None:
535-
"""Test that constructing a ContinuousBatchingConfig sets NCCL_GRAPH_MIXING_SUPPORT=0 by default and only sets
536-
it when the disable_nccl_graph_mixing flag is on."""
537-
original = os.environ.pop("NCCL_GRAPH_MIXING_SUPPORT", None)
535+
"""Test that ContinuousBatchingConfig sets NCCL_GRAPH_MIXING_SUPPORT=0 only under a distributed launch
536+
(WORLD_SIZE > 1) and respects the disable_nccl_graph_mixing flag."""
537+
original_nccl = os.environ.pop("NCCL_GRAPH_MIXING_SUPPORT", None)
538+
original_ws = os.environ.pop("WORLD_SIZE", None)
538539
try:
539-
# Default: env var is set to "0"
540+
# Single-GPU launch (no WORLD_SIZE): env var is left untouched
541+
ContinuousBatchingConfig()
542+
self.assertNotIn("NCCL_GRAPH_MIXING_SUPPORT", os.environ)
543+
544+
# Distributed launch (WORLD_SIZE > 1): env var is set to "0"
545+
os.environ["WORLD_SIZE"] = "2"
540546
ContinuousBatchingConfig()
541547
self.assertEqual(os.environ.get("NCCL_GRAPH_MIXING_SUPPORT"), "0")
542548

543-
# Explicitly disabled flag: env var is left untouched
549+
# Explicitly disabled flag: env var is left untouched even under a distributed launch
544550
os.environ.pop("NCCL_GRAPH_MIXING_SUPPORT", None)
545551
ContinuousBatchingConfig(disable_nccl_graph_mixing=False)
546552
self.assertNotIn("NCCL_GRAPH_MIXING_SUPPORT", os.environ)
@@ -550,10 +556,14 @@ def test_continuous_batching_config_disables_nccl_graph_mixing(self) -> None:
550556
ContinuousBatchingConfig()
551557
self.assertEqual(os.environ.get("NCCL_GRAPH_MIXING_SUPPORT"), "1")
552558
finally:
553-
if original is None:
559+
if original_nccl is None:
554560
os.environ.pop("NCCL_GRAPH_MIXING_SUPPORT", None)
555561
else:
556-
os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = original
562+
os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = original_nccl
563+
if original_ws is None:
564+
os.environ.pop("WORLD_SIZE", None)
565+
else:
566+
os.environ["WORLD_SIZE"] = original_ws
557567

558568

559569
@require_torch_accelerator

0 commit comments

Comments
 (0)