diff --git a/benchmark_v2/benchmark_scripts/continuous_batching_overall.py b/benchmark_v2/benchmark_scripts/continuous_batching_overall.py index 4833f137456f..01069e6d7d30 100644 --- a/benchmark_v2/benchmark_scripts/continuous_batching_overall.py +++ b/benchmark_v2/benchmark_scripts/continuous_batching_overall.py @@ -9,12 +9,18 @@ import gc import json import time +import types +from collections.abc import Callable from dataclasses import asdict, dataclass from pathlib import Path from typing import Any -import datasets import torch +from lighteval.models.model_output import ModelResponse +from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig +from lighteval.tasks.prompt_manager import PromptManager +from lighteval.tasks.registry import Registry +from lighteval.tasks.requests import Doc from tabulate import tabulate from transformers import AutoModelForCausalLM, AutoTokenizer, ContinuousBatchingConfig, GenerationConfig @@ -24,16 +30,101 @@ RESULTS_DIR = Path(__file__).parent.parent / "benchmark_results/cb_overall/" +def _fmt(val: Any, spec: str = "", missing: str = "X") -> str: + """Format `val` per `spec`, or return `missing` if val is None.""" + return format(val, spec) if val is not None else missing + + +def _build_gsm8k_platinum_module() -> types.ModuleType: + """Define the gsm8k_platinum custom task inline so lighteval's Registry can pick it up via `custom_tasks=`.""" + + def gsm8k_platinum_prompt(line, task_name=None): + return Doc( + task_name=task_name, + query=f"Question: {line['question']}\nAnswer:", + choices=[f" {line['answer']}"], + gold_index=0, + ) + + metrics = list(Registry().load_all_task_configs()["gsm8k"].metrics) + + mod = types.ModuleType("_gsm8k_platinum_inline") + mod.TASKS_TABLE = [ + LightevalTaskConfig( + name="gsm8k_platinum", + prompt_function=gsm8k_platinum_prompt, + hf_repo="madrylab/gsm8k-platinum", + hf_subset="main", + evaluation_splits=("test",), + few_shots_split="test", + few_shots_select="random_sampling", + generation_size=256, + stop_sequence=["Question:"], + metrics=metrics, + ), + ] + return mod + + +def _build_lighteval_inputs_scorer( + tokenizer: AutoTokenizer, + *, + task_spec: str, + task_name: str, + use_chat_template: bool, + custom_tasks: Any = None, + primary_metric: str | None = None, + stop_sequences: tuple[str, ...] = (), +) -> tuple[list[list[int]], Callable[[Any], float]]: + """Tokenize prompts and build a per-sample scorer for any lighteval task.""" + r = Registry(tasks=task_spec, **({"custom_tasks": custom_tasks} if custom_tasks else {})) + metric = r.task_to_configs[task_name][0].metrics[0] + tasks_dict = r.load_tasks() + LightevalTask.load_datasets(tasks_dict, 1) + docs = next(iter(tasks_dict.values())).get_docs() + + pm = PromptManager(use_chat_template=use_chat_template, tokenizer=tokenizer, system_prompt=None) + prompts = [pm.prepare_prompt(doc) for doc in docs] + inputs = tokenizer(prompts, add_special_tokens=not use_chat_template)["input_ids"] + + def score(outputs) -> float: + scores = [] + for doc, (_, out) in zip(docs, outputs.items()): + text = tokenizer.decode(out.generated_tokens, skip_special_tokens=True) + for s in stop_sequences: + text = text.split(s, 1)[0] + value = metric.sample_level_fn.compute(doc, ModelResponse(text=[text])) + # Grouped metrics return a dict keyed by sub-metric — pick the primary one. + scores.append(value[primary_metric] if isinstance(value, dict) else value) + return sum(scores) / len(scores) + + return inputs, score + + # Data helpers -def get_tokenized_gms8k(tokenizer: AutoTokenizer) -> list[list[int]]: - """Tokenize the GSM8K questions as chat prompts.""" - dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") - batched_inputs = [] - for item in dataset: - messages = [{"role": "user", "content": item["question"]}] - inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True) # type: ignore - batched_inputs.append(inputs if isinstance(inputs, list) else inputs["input_ids"]) - return batched_inputs +def get_tokenized_gsm8k( + tokenizer: AutoTokenizer, n_fewshot: int = 8 +) -> tuple[list[list[int]], Callable[[Any], float]]: + """GSM8K-Platinum few-shot inputs and scorer using the same lighteval extractive_match as the gsm8k task.""" + return _build_lighteval_inputs_scorer( + tokenizer, + task_spec=f"gsm8k_platinum|{n_fewshot}", + task_name="gsm8k_platinum", + use_chat_template=False, + custom_tasks=_build_gsm8k_platinum_module(), + stop_sequences=("Question:",), + ) + + +def get_tokenized_ifeval(tokenizer: AutoTokenizer) -> tuple[list[list[int]], Callable[[Any], float]]: + """IFEval inputs (chat-templated, 0-shot) and scorer reporting prompt-level strict accuracy.""" + return _build_lighteval_inputs_scorer( + tokenizer, + task_spec="ifeval|0", + task_name="ifeval", + use_chat_template=True, + primary_metric="prompt_level_strict_acc", + ) def get_random_data(batch_size: int, num_tokens: int, vocab_size: int = 16000) -> list[list[int]]: @@ -57,6 +148,7 @@ class BenchmarkEntry: num_tokens: int | None = None throughput_tok_per_sec: float | None = None peak_memory_gb: float | None = None + accuracy: float | None = None error: str | None = None @@ -69,9 +161,10 @@ def _config_summary(cfg: Any) -> dict[str, Any]: class BenchmarkResults: """Holds all CB benchmark runs and the shared model they execute against.""" - def __init__(self, model_id: str, attn_impl: str): + def __init__(self, model_id: str, attn_impl: str, tp_size: int = 1): self.model_id = model_id self.attn_impl = attn_impl + self.tp_size = tp_size self.entries: list[BenchmarkEntry] = [] def cleanup(self) -> None: @@ -80,13 +173,11 @@ def cleanup(self) -> None: torch.cuda.reset_peak_memory_stats() def _get_model(self) -> Any: - model = None self.cleanup() - model = AutoModelForCausalLM.from_pretrained( - self.model_id, attn_implementation=self.attn_impl, device_map="auto" - ) - model = model.eval() - return model + # tp_plan and device_map are mutually exclusive — TP uses its own placement. + placement = {"tp_plan": "auto"} if self.tp_size > 1 else {"device_map": 0} + model = AutoModelForCausalLM.from_pretrained(self.model_id, attn_implementation=self.attn_impl, **placement) + return model.eval() def add_benchmark( self, @@ -95,13 +186,12 @@ def add_benchmark( cb_config: ContinuousBatchingConfig, gen_config: GenerationConfig | None = None, label: str | None = None, + score_fn: Callable[[Any], float] | None = None, ) -> BenchmarkEntry: """Run one CB benchmark and record time, tokens, and peak memory.""" gen_config = GenerationConfig() if gen_config is None else gen_config gen_config.max_new_tokens = max_new_tokens - # Disable EOS so every request runs to max_new_tokens — consistent benchmarking. - gen_config.eos_token_id = -1 model = self._get_model() @@ -135,9 +225,12 @@ def add_benchmark( entry.num_tokens = num_tokens entry.throughput_tok_per_sec = num_tokens / gen_time if gen_time > 0 else 0.0 entry.peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3) + if score_fn is not None: + entry.accuracy = score_fn(outputs) print( f" {gen_time:.2f}s, {num_tokens} tokens, " f"{entry.throughput_tok_per_sec:.2f} tok/s, peak {entry.peak_memory_gb:.2f} GB" + + (f", acc {entry.accuracy:.3f}" if entry.accuracy is not None else "") ) except Exception as e: entry.error = str(e) @@ -184,10 +277,11 @@ def print_summary(self) -> None: "samples": e.num_samples, "avg_in": f"{e.avg_input_tokens:.1f}", "max_new": e.max_new_tokens, - "time (s)": f"{e.time_seconds:.2f}" if e.time_seconds is not None else "X", - "tokens": e.num_tokens if e.num_tokens is not None else "X", - "tok/s": f"{e.throughput_tok_per_sec:.2f}" if e.throughput_tok_per_sec is not None else "ERROR", - "mem (GB)": f"{e.peak_memory_gb:.2f}" if e.peak_memory_gb is not None else "X", + "time (s)": _fmt(e.time_seconds, ".2f"), + "tokens": _fmt(e.num_tokens, "d"), + "tok/s": _fmt(e.throughput_tok_per_sec, ".2f", "ERROR"), + "mem (GB)": _fmt(e.peak_memory_gb, ".2f"), + "acc": _fmt(e.accuracy, ".3f", "-"), } for e in self.entries ] @@ -195,24 +289,22 @@ def print_summary(self) -> None: def compare_to(self, baseline: "BenchmarkResults") -> None: """Print a side-by-side throughput comparison against a baseline run.""" - baseline_by_label = {e.label: e for e in baseline.entries} - rows = [] - for e in self.entries: - base = baseline_by_label.get(e.label) - base_tp = base.throughput_tok_per_sec if base else None - cur_tp = e.throughput_tok_per_sec - if isinstance(base_tp, (int, float)) and isinstance(cur_tp, (int, float)) and base_tp > 0: - diff_str = f"{(cur_tp - base_tp) / base_tp * 100:+.1f}%" - else: - diff_str = "N/A" - rows.append( - { - "label": e.label, - "baseline (tok/s)": f"{base_tp:.2f}" if isinstance(base_tp, (int, float)) else "N/A", - "current (tok/s)": (f"{cur_tp:.2f}" if isinstance(cur_tp, (int, float)) else (e.error or "N/A")), - "diff": diff_str, - } - ) + base_tps = {e.label: e.throughput_tok_per_sec for e in baseline.entries} + + def diff(cur: float | None, base: float | None) -> str: + if cur is None or not base: + return "N/A" + return f"{(cur - base) / base * 100:+.1f}%" + + rows = [ + { + "label": e.label, + "baseline (tok/s)": _fmt(base_tps.get(e.label), ".2f", "N/A"), + "current (tok/s)": _fmt(e.throughput_tok_per_sec, ".2f", e.error or "N/A"), + "diff": diff(e.throughput_tok_per_sec, base_tps.get(e.label)), + } + for e in self.entries + ] print(f"\nComparison against baseline (model={baseline.model_id}):") print(tabulate(rows, headers="keys", tablefmt="github")) @@ -224,30 +316,33 @@ def compare_to(self, baseline: "BenchmarkResults") -> None: parser.add_argument("--compare-to", type=str, default=None, help="Name of a previous run to compare against.") parser.add_argument("--model-id", type=str, default="meta-llama/Llama-3.1-8B-Instruct") parser.add_argument("--attn", type=str, default="kernels-community/flash-attn3") + parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size (1 = no TP).") cli_args = parser.parse_args() - results = BenchmarkResults(model_id=cli_args.model_id, attn_impl=cli_args.attn) - - # GSM8K benchmarks (256 max new tokens) + results = BenchmarkResults(model_id=cli_args.model_id, attn_impl=cli_args.attn, tp_size=cli_args.tp_size) + # GSM8K benchmarks (256 max new tokens) — gsm8k_platinum dataset, 8-shot, lighteval extractive_match tokenizer = AutoTokenizer.from_pretrained(cli_args.model_id, padding_side="left") - gsm8k_data = get_tokenized_gms8k(tokenizer) + gsm8k_data, gsm8k_score_fn = get_tokenized_gsm8k(tokenizer) ## No options results.add_benchmark( data=gsm8k_data, max_new_tokens=256, cb_config=ContinuousBatchingConfig(), + gen_config=GenerationConfig(eos_token_id=-1), label="gsm8k_default", + score_fn=gsm8k_score_fn, ) - ## With sampling + ## With sampling. Recommended chat sampling (T=0.6, top_p=0.9), low enough that math reasoning isn't derailed results.add_benchmark( data=gsm8k_data, max_new_tokens=256, cb_config=ContinuousBatchingConfig(), - gen_config=GenerationConfig(do_sample=True), + gen_config=GenerationConfig(eos_token_id=-1, do_sample=True, temperature=0.6, top_p=0.9), label="gsm8k_sampling", + score_fn=gsm8k_score_fn, ) ## With compile @@ -255,7 +350,9 @@ def compare_to(self, baseline: "BenchmarkResults") -> None: data=gsm8k_data, max_new_tokens=256, cb_config=ContinuousBatchingConfig(use_default_compile_configs=True), + gen_config=GenerationConfig(eos_token_id=-1), label="gsm8k_compile", + score_fn=gsm8k_score_fn, ) ## No decode fast path @@ -263,7 +360,29 @@ def compare_to(self, baseline: "BenchmarkResults") -> None: data=gsm8k_data, max_new_tokens=256, cb_config=ContinuousBatchingConfig(max_blocks_per_request=0), + gen_config=GenerationConfig(eos_token_id=-1), label="gsm8k_no_fast_decode", + score_fn=gsm8k_score_fn, + ) + + ## Bare-bones CB config + results.add_benchmark( + data=gsm8k_data, + max_new_tokens=256, + cb_config=ContinuousBatchingConfig(max_blocks_per_request=0, use_async_batching=False, use_cuda_graph=False), + gen_config=GenerationConfig(eos_token_id=-1), + label="gsm8k_bare_bones", + score_fn=gsm8k_score_fn, + ) + + # IFEval: 0-shot chat prompts; uses real EOS so instruction-following metrics see the model's natural stop. + ifeval_data, ifeval_score_fn = get_tokenized_ifeval(tokenizer) + results.add_benchmark( + data=ifeval_data, + max_new_tokens=1280, + cb_config=ContinuousBatchingConfig(), + label="ifeval_default", + score_fn=ifeval_score_fn, ) # Raw benchmarks (synthetic data, variable max new tokens) @@ -274,6 +393,7 @@ def compare_to(self, baseline: "BenchmarkResults") -> None: data=get_random_data(batch_size=32, num_tokens=256), max_new_tokens=length, cb_config=ContinuousBatchingConfig(use_default_compile_configs=True), + gen_config=GenerationConfig(eos_token_id=-1), label=f"rollouts_{length}", ) @@ -282,6 +402,7 @@ def compare_to(self, baseline: "BenchmarkResults") -> None: data=get_random_data(batch_size=20, num_tokens=256), max_new_tokens=256, cb_config=ContinuousBatchingConfig(num_blocks=16), + gen_config=GenerationConfig(eos_token_id=-1), label="few_blocks", ) @@ -290,17 +411,16 @@ def compare_to(self, baseline: "BenchmarkResults") -> None: data=get_random_data(batch_size=50, num_tokens=256), max_new_tokens=256, cb_config=ContinuousBatchingConfig(), - gen_config=GenerationConfig(do_sample=True, num_return_sequences=8), + gen_config=GenerationConfig(eos_token_id=-1, do_sample=True, num_return_sequences=8), label="multi_return_seq", ) - # Post processing and display - - results.print_summary() - - if cli_args.compare_to: - baseline = BenchmarkResults.load_most_recent(cli_args.compare_to) - results.compare_to(baseline=baseline) - - if cli_args.name: - results.save(cli_args.name) + # Post processing and display. Only on rank 0 in TP runs to avoid duplicate output / file writes. + is_rank_zero = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + if is_rank_zero: + results.print_summary() + if cli_args.compare_to: + baseline = BenchmarkResults.load_most_recent(cli_args.compare_to) + results.compare_to(baseline=baseline) + if cli_args.name: + results.save(cli_args.name) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index aa1b8cc2cd79..6da3d40dbfcc 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -1653,6 +1653,8 @@ class ContinuousBatchingConfig: Scheduler type to use. return_logprobs (`bool`, *optional*, defaults to `False`): Whether to return log probabilities along with the generated tokens. + seed (`int | None`, *optional*): + An optional seed for generation. If not specified, the internal seed will be set to a random value. cpu_offload_space (`float`, *optional*, defaults to 0.0): CPU swap space in GiB for KV cache offloading. A pre-allocated pinned CPU buffer of this size is created at initialization. When the GPU cache is full, evicted requests' KV caches are copied here @@ -1666,6 +1668,8 @@ class ContinuousBatchingConfig: Enable per-request logits processor parameters. Default is False. drop_unsupported_processors (`bool`, *optional*, defaults to `True`): Remove unsupported logits processors instead of erroring. Default is True. + disable_nccl_graph_mixing (`bool`, *optional*, defaults to `True`): + Disable NCCL's safety net for parallel graph-captured comms. Never happens in CB and gives TP a perf boost. """ # Size of each KV cache block @@ -1719,6 +1723,9 @@ class ContinuousBatchingConfig: # probabilities will be returned along with the generated tokens in the generation output. return_logprobs: bool = False + # An optional seed for generation. If not specified, the internal seed will be set to a random value. + seed: int | None = None + # CPU swap space in GiB for KV cache offloading. When the GPU cache is full and a request must be evicted, its KV # cache is copied to this pre-allocated pinned CPU buffer instead of being discarded. Default to 0.0 GiB. You can # also set this to None to dimension the pool using only the safety threshold, but this will error out if psutil is @@ -1739,44 +1746,18 @@ class ContinuousBatchingConfig: # are kept but warnings are logged for unsupported/unknown ones. drop_unsupported_processors: bool = True - def account_for_cb_deprecated_arguments( - self, - max_queue_size: int = 0, - q_padding_interval_size: int = 0, - kv_padding_interval_size: int = 0, - allow_block_sharing: bool = True, - use_async_batching: bool | None = None, - max_cached_graphs: int = 0, - ) -> None: - """Some arguments given to `generate_batch`, `init_continuous_batching` or `continuous_batching_context_manager` - are now deprecated and are expected inside the continuous batching config. This method checks if any were - passed and accounts for them in the continuous batching config. It raises a deprecation warning if any were - passed. - """ - kwargs_to_warn = [] - if max_queue_size > 0: - kwargs_to_warn.append("max_queue_size") - self.max_queue_size = max_queue_size - if q_padding_interval_size > 0: - kwargs_to_warn.append("q_padding_interval_size") - self.q_padding_interval_size = q_padding_interval_size - if kv_padding_interval_size > 0: - kwargs_to_warn.append("kv_padding_interval_size") - self.kv_padding_interval_size = kv_padding_interval_size - if not allow_block_sharing: # config default is True, so False means the user explicitly set it to False - kwargs_to_warn.append("allow_block_sharing") - self.allow_block_sharing = allow_block_sharing - if use_async_batching is not None: - kwargs_to_warn.append("use_async_batching") - self.use_async_batching = use_async_batching - if max_cached_graphs > 0: - kwargs_to_warn.append("max_cached_graphs") - self.max_cached_graphs = max_cached_graphs - if kwargs_to_warn: + # Disable NCCL's safety net for parallel graph-captured communications. This means it is no longer safe to replay a + # CUDA graph with NCCL communication at the same time as 1. another CUDA graph with captured comms 2. an eager comm. + # This is turned on by default because the above never happens in CB and this gives a nice perf boost. + disable_nccl_graph_mixing: bool = True + + def __post_init__(self): + # Only turn off graph mixing support if TP is on + if self.disable_nccl_graph_mixing and int(os.environ.get("WORLD_SIZE", "1")) > 1: logger.warning( - "The following arguments were provided to a continuous batching entry point instead of being passed " - "through the continuous_batching_config: " + ", ".join(kwargs_to_warn) + "Setting NCCL_GRAPH_MIXING_SUPPORT = 0 because disable_nccl_graph_mixing is True and WORLD_SIZE > 1." ) + os.environ.setdefault("NCCL_GRAPH_MIXING_SUPPORT", "0") @property def cuda_graph_booleans(self) -> tuple[bool, bool]: @@ -1789,5 +1770,5 @@ def cuda_graph_booleans(self) -> tuple[bool, bool]: @property def fallback_max_blocks_per_request(self) -> int: - """Returns the max blocks per request.""" + """Fallback if no user-hint is given and decode path is available.""" return 32 diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index 393b81530f04..32f4cc7caafb 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -22,6 +22,7 @@ from ...utils.generic import is_flash_attention_requested from ...utils.metrics import attach_tracer, traced from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator +from .distributed import DistributedHelper from .initialization import resolve_max_memory_percent from .requests import RequestState, RequestStatus, get_device_and_memory_breakdown, logger @@ -122,8 +123,9 @@ def __init__( config: PreTrainedConfig, continuous_batching_config: ContinuousBatchingConfig, device: torch.device | str, + distributed_helper: DistributedHelper, + tp_plan: dict[str, Any], dtype: torch.dtype = torch.float16, - tp_size: int | None = None, ) -> None: """Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has only full attention layers. @@ -132,8 +134,9 @@ def __init__( config: Model configuration continuous_batching_config: Continuous batching configuration containing cache parameters device: Device for the cache tensors + distributed_helper: TP-aware helper. Used to dispatch attention heads and ensure coherent cache size + tp_plan: Tensor parallelism plan dtype: Data type of the cache - tp_size: Tensor parallelism size """ self.config = config self.dtype = dtype @@ -163,14 +166,23 @@ def __init__( self.layer_index_to_group_indices[layer] = (i, j) self.sliding_windows[layer] = sliding_window - # Handle TP (or dont) - if tp_size is not None and tp_size > 1: + # Check if the KV heads are part of the TP plan. If they are not, the cache does not need plan for TP. + # TODO: this is fragile. If your model fails to TP properly because of this, please open an issue. + kv_is_tp = True + for key in ["layers.*.self_attn.k_proj", "layers.*.self_attn.v_proj"]: + if not (key in tp_plan or "model." + key in tp_plan): + kv_is_tp = False + break + + # If the KV heads are TP'ed, each KV head is dispatched to a different GPU, so the effective number of KV heads + # per GPU is simply divided by the TP size + tp_size = distributed_helper.tp_size + if tp_size > 1 and kv_is_tp: if self.num_key_value_heads % tp_size != 0: raise ValueError( f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}." ) - # If the model is using tensor parallelism, we need to adjust the number of heads accordingly. - # self.num_key_value_heads //= tp_size # TODO: why is this commented out? + self.num_key_value_heads //= tp_size # Infer number of blocks and max batch tokens page_size = self.head_dim * self.num_key_value_heads @@ -214,6 +226,12 @@ def __init__( cache_dtype=self.dtype, ) + # For TP, align num_blocks and max_batch_tokens to the minimal value across the TP group + if tp_size > 1: + sync = torch.tensor([num_blocks, max_batch_tokens], device=self.device, dtype=torch.int64) + distributed_helper.tp_all_reduce_min(sync) + num_blocks, max_batch_tokens = int(sync[0].item()), int(sync[1].item()) + # Add the inferred attributes to the class self.num_blocks = num_blocks self.max_batch_tokens = max_batch_tokens @@ -270,7 +288,7 @@ def __init__( # We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed self.use_prefix_sharing = self.allow_block_sharing and group_types == ["full_attention"] - self._block_manager = BlockManager(num_blocks, self.block_size) + self._block_manager = BlockManager(num_blocks, self.block_size, tp_on=tp_size > 1) self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests # For block table support, we lazy init the name of the block table key diff --git a/src/transformers/generation/continuous_batching/cache_manager.py b/src/transformers/generation/continuous_batching/cache_manager.py index 71d0d3c23f71..ffcb06738468 100644 --- a/src/transformers/generation/continuous_batching/cache_manager.py +++ b/src/transformers/generation/continuous_batching/cache_manager.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import hashlib from abc import ABC, abstractmethod +from array import array from collections import deque from collections.abc import Iterator from math import ceil @@ -73,10 +75,11 @@ class BlockManager: it is in use. """ - def __init__(self, num_blocks: int, block_size: int) -> None: + def __init__(self, num_blocks: int, block_size: int, tp_on: bool) -> None: """Initializes the block manager with a given number of blocks (num_blocks) of size (block_size).""" self.num_blocks = num_blocks self.block_size = block_size + self.tp_on = tp_on self._uninit_block_ids = deque(range(num_blocks)) self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set self._hash_to_id: dict[int, int] = {} @@ -276,7 +279,19 @@ def mark_shareable_blocks_as_complete( def compute_hash(self, parent_hash: int | None, tokens: list[int], group_id: int) -> int: """Computes the hash of a block identified by the (tokens) it contains, its (parent_hash) and the layer (group_id) it belong to. If the block has no parent, the parent hash is None.""" - return hash((parent_hash, tuple(tokens), group_id)) + # If TP is on, we cannot use python `hash` because it depends on the process (it's per-process salted) + # TODO: figure out if this is really a problem. Even if hashes diverge per-process, does that break anything? + if self.tp_on: + h = hashlib.blake2b(digest_size=8) + if parent_hash is not None: + h.update(parent_hash.to_bytes(8, "little", signed=False)) + h.update(array("i", tokens).tobytes()) + h.update(group_id.to_bytes(4, "little", signed=False)) + hash_ = int.from_bytes(h.digest(), "little", signed=False) + # Otherwise, use `hash` + else: + hash_ = hash((parent_hash, tuple(tokens), group_id)) + return hash_ class CacheAllocator(ABC): diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 90c939be0559..e7f276bb56fb 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -35,13 +35,14 @@ from ..logits_process import LogitsProcessorList from .cache import PagedAttentionCache from .cb_logits_processors import ContinuousBatchingLogitsProcessorList +from .distributed import DistributedHelper from .initialization import resolve_continuous_batching_config from .input_outputs import ContinuousBatchingAsyncIOs, ContinuousBatchingIOs from .model_runner import ModelRunner from .offloading_manager import OffloadingManager from .requests import GenerationOutput, RequestState, RequestStatus, logger from .scheduler import SCHEDULER_MAPPING, FIFOScheduler, Scheduler -from .utils import WorkloadHints +from .utils import WorkloadHints, drain_queue """ @@ -141,12 +142,14 @@ def __init__( generation_config: GenerationConfig, continuous_batching_config: ContinuousBatchingConfig, logit_processor: ContinuousBatchingLogitsProcessorList, - input_queue: queue.Queue, + input_queue: queue.Queue | None, + cancel_queue: queue.Queue | None, output_router: OutputRouter, stop_event: threading.Event, model_device: torch.device, model_dtype: torch.dtype, scheduler: Scheduler, + distributed_helper: DistributedHelper, ) -> None: """Initialize the continuous batch processor. @@ -156,28 +159,36 @@ def __init__( generation_config: The generation configuration continuous_batching_config: The continuous batching configuration logit_processor: The [`ContinuousBatchingLogitsProcessorList`] object used to process the logits. - input_queue: Queue for incoming requests + input_queue: Queue for incoming requests. Is None if this process is not a TP driver. + cancel_queue: Queue for cancellation request_ids. Is None if this process is not a TP driver. output_router: An [`OutputRouter`] object that routes outputs to handlers or the output queue. stop_event: Event to signal processing should stop model_device: Device for model inputs/outputs model_dtype: Data type for model inputs/outputs scheduler: The [`Scheduler`] to use + distributed_helper: The [`DistributedHelper`] to use """ self.cache = cache self.config = config self.cb_config = continuous_batching_config self.logit_processor = logit_processor self.input_queue = input_queue + self.cancel_queue = cancel_queue self.output_router = output_router self.stop_event = stop_event self.model_device = model_device self.model_dtype = model_dtype self.scheduler = scheduler + self.distributed_helper = distributed_helper # Generation-related attributes self.do_sample = getattr(generation_config, "do_sample", True) self.return_logprobs = continuous_batching_config.return_logprobs + # Get an integer seed for the TP group. Also work for no TP. + self.distributed_helper.set_tp_seed(continuous_batching_config.seed, model_device) + self.driver_stopped = False # will be set to True if the TP driver stops the generation loop + # Retrieve the size of the sliding window if there is one self.sliding_window = 1 if getattr(config, "sliding_window", None) is None else config.sliding_window @@ -245,25 +256,45 @@ def reset(self) -> None: self.inputs_and_outputs.reset() self.cache.free_all_requests() self.metrics = ContinuousBatchProcessorMetrics(self.cache.max_batch_tokens) + self.driver_stopped = False @traced - def _get_new_requests(self) -> None: - """Pull new requests from the input queue and add to waiting list.""" - while not self.input_queue.empty(): + def _get_new_requests(self) -> bool: + """Pull new requests and cancellations from the queues and apply them to the scheduler. In the context of TP, + only the TP driver of the TP group does this, and broadcasts the new_states / cancellations to other TP ranks. + Returns a boolean indicating if the TP driver for this group has stopped.""" + # The payload is filled for TP drivers only, it stays empty for other processes + payload = ([], []) + if self.input_queue is not None and self.cancel_queue is not None: + payload = (drain_queue(self.input_queue), drain_queue(self.cancel_queue)) + + # Cheap CPU-only comm to know if there is a payload to broadcast or if the driver is stopping + if self.stop_event.is_set(): + signal = -1 + else: + signal = len(payload[0]) + len(payload[1]) + signal = self.distributed_helper.tp_broadcast_int(signal) + + # If the signal is 0, it means the driver has nothing to send: stop here + if signal == 0: + return False + # Else if it is strictly below 0, it means the driver is stopping: do the same + elif signal < 0: + return True + # Otherwise, the payload size is above 0, so there is a payload to broadcast and unpack (no-op for TP size 1) + new_states, cancellations = self.distributed_helper.tp_broadcast_object(payload) + + # All ranks apply the same updates in the same order. + for state in new_states: try: - state = self.input_queue.get_nowait() - if state is None: # Sentinel value - continue self.logit_processor.check_kwargs(state.logit_processor_kwargs) self.scheduler.add_waiting_request(state) - - except queue.Empty: - break except Exception as e: logger.error(f"Error processing new request: {e}", exc_info=True) - state: RequestState = locals().get("state") - if state is not None: - self._handle_request_error(e, state) + self._handle_request_error(e, state) + for request_id in cancellations: + self.scheduler.set_request_cancellation(request_id) + return False @traced def _handle_request_error(self, error: Exception, state: RequestState) -> None: @@ -285,8 +316,10 @@ def prepare_next_batch(self) -> bool: """Prepare tensors and metadata for the next model forward pass. Returns True if there are requests to process, False otherwise.""" - # Get new requests from the queue, stop if there are no pending requests - self._get_new_requests() + # Get new requests from the queue. If the driver signaled collective stop, surface it to the manager. + self.driver_stopped = self._get_new_requests() + if self.driver_stopped: + return False cancelled_states = self.scheduler.clear_cancelled_requests() # Also free CPU-offloaded cache for cancelled states. This is CPU-only, so it isn't batched like D2H transfers for state in cancelled_states: @@ -491,6 +524,7 @@ def __init__( self.warmed_up = False # Set to True after warmup is completed. Useful for persistent managers. self.input_queue = queue.Queue(maxsize=continuous_batching_config.max_queue_size) + self.cancel_queue: queue.Queue[str] = queue.Queue() self._has_new_requests = threading.Event() self.output_router = OutputRouter() self.stop_event = threading.Event() @@ -500,6 +534,13 @@ def __init__( self._request_lock = threading.Lock() self.fatal_error: Exception | None = None + # Infer if this process is the driver of its own TP group + self.distributed_helper = DistributedHelper(device_mesh=getattr(self.model, "_device_mesh", None)) + self.is_tp_driver = self.distributed_helper.is_tp_driver + # If TP is on, check if NCCL graph mixing is disabled (helps with performance) + if continuous_batching_config.disable_nccl_graph_mixing: + self.distributed_helper.maybe_warn_nccl_graph_mixing() + # Generation config related arguments num_return_sequences = getattr(generation_config, "num_return_sequences", None) self.num_return_sequences = num_return_sequences if num_return_sequences is not None else 1 @@ -529,7 +570,7 @@ def switch_to_paged_attn(self, model: ProtoPretrainedModel) -> None: @traced def start(self) -> None: """Start the background generation thread.""" - if self._generation_thread is not None and self._generation_thread.is_alive(): + if self.is_running(): logger.warning("Manager thread is already running.") return self.stop_event.clear() @@ -578,9 +619,10 @@ def stop(self, block: bool = True, timeout: float | None = None, keep_for_next_s if block: self.join(stop_trigger_time, timeout) - # If the manager is not being kept for next session, we clear the batch processor + # If the manager is not being kept for next session, we clear the batch processor and destroy the CPU comm group if not keep_for_next_session: self.batch_processor = None + self.distributed_helper.destroy_cpu_comm_group() # Otherwise, we keep the batch processor and cache the manager as a model attribute else: logger.info("Continuous batching manager will be kept for next session.") @@ -617,8 +659,8 @@ def add_request( record_timestamps: bool = False, eos_token_id: int | list[int] | None = None, **logit_processor_kwargs: Any, - ) -> str: - """Add a new generation request to the queue. + ) -> str | None: + """Add a new generation request to the queue. If the process is not a TP driver, this is a no-op. Args: input_ids: Input token IDs to use as prompt @@ -630,13 +672,17 @@ def add_request( logit_processor_kwargs: Keyword arguments for the logits processor. Returns: - str: The request ID + str | None: The request ID if the process is a TP driver, None otherwise. """ if request_id is None: with self._request_lock: request_id = f"req_{self._request_counter}" self._request_counter += 1 + # If this process is not a TP driver, it does not enqueue new requests from this entry point + if not self.is_tp_driver: + return None # this value should never be used anyway because non-TP drivers do not enqueue requests + max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens eos_token_id = self.generation_config.eos_token_id if eos_token_id is None else eos_token_id @@ -692,13 +738,11 @@ def add_requests( return request_ids def cancel_request(self, request_id: str) -> None: - """Cancel a request by its ID. - - Args: - request_id: The ID of the request to cancel - """ - if self.batch_processor is not None: - self.batch_processor.scheduler.set_request_cancellation(request_id) + """Cancel a request by its ID. If this called from a process that is not a TP driver, it's a no-op: only TP + driver processes interact with the manager.""" + if self.is_tp_driver: + self.cancel_queue.put(request_id) + self._has_new_requests.set() # TODO:handle benchmarking properly when updating / fixing the requeue logic def get_result(self, request_id: str | None = None, timeout: float | None = None) -> GenerationOutput | None: @@ -776,11 +820,12 @@ def _generation_step(self) -> None: def _create_batch_processor(self) -> ContinuousBatchProcessor: # Create the PagedAttentionCache paged_attention_cache = PagedAttentionCache( - self.model.config, - self.continuous_batching_config, - self.model.device, - self.model.dtype, - tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting + config=self.model.config, + continuous_batching_config=self.continuous_batching_config, + device=self.model.device, + distributed_helper=self.distributed_helper, + tp_plan=getattr(self.model, "tp_plan", {}), + dtype=self.model.dtype, ) self._use_prefix_sharing = paged_attention_cache.use_prefix_sharing # update the approximation @@ -802,12 +847,14 @@ def _create_batch_processor(self) -> ContinuousBatchProcessor: generation_config=self.generation_config, continuous_batching_config=self.continuous_batching_config, logit_processor=self.logit_processor, - input_queue=self.input_queue, + input_queue=self.input_queue if self.is_tp_driver else None, + cancel_queue=self.cancel_queue if self.is_tp_driver else None, output_router=self.output_router, stop_event=self.stop_event, model_device=self.model.device, model_dtype=self.model.dtype, scheduler=scheduler(paged_attention_cache), + distributed_helper=self.distributed_helper, ) return batch_processor @@ -835,7 +882,8 @@ def _run_generation_loop(self) -> None: self._generation_step() self.current_batch += 1 - while (not self.stop_event.is_set()) or batch_processor.has_pending_requests(): + # The loop continues until the TP driver stops or there are no more pending requests + while (not batch_processor.driver_stopped) or batch_processor.has_pending_requests(): self._inner_generation_loop(batch_processor) self.current_batch += 1 @@ -903,7 +951,6 @@ def init_continuous_batching( generation_config: GenerationConfig | None = None, continuous_batching_config: ContinuousBatchingConfig | None = None, workload_hints: WorkloadHints | None = None, - **deprecated_kwargs, ) -> ContinuousBatchingManager: """Initialize a manager for continuous batching inference. @@ -912,9 +959,6 @@ def init_continuous_batching( continuous_batching_config: An optional continuous batching configuration workload_hints: Optional WorkloadHints to help the continuous batching manager make better decisions for default values - **deprecated_kwargs: Deprecated arguments that are now passed in the continuous_batching_config. Those are: - max_queue_size, q_padding_interval_size, kv_padding_interval_size, allow_block_sharing, - use_async_batching, max_cached_graphs Returns: `ContinuousBatchingManager`: The manager instance to add requests and retrieve results. """ @@ -947,7 +991,6 @@ def init_continuous_batching( continuous_batching_config = gen_config.continuous_batching_config else: continuous_batching_config = ContinuousBatchingConfig() - continuous_batching_config.account_for_cb_deprecated_arguments(**deprecated_kwargs) # Create and return the manager return ContinuousBatchingManager( @@ -975,7 +1018,6 @@ def continuous_batching_context_manager( persistent_manager: bool = False, warmup: bool = True, workload_hints: WorkloadHints | None = None, - **deprecated_kwargs, ) -> Generator[ContinuousBatchingManager]: """A context manager to safely use the continuous batching manager. Arguments are similar to the ones of `init_continuous_batching`, except for: @@ -987,7 +1029,6 @@ def continuous_batching_context_manager( generation_config=generation_config, continuous_batching_config=continuous_batching_config, workload_hints=workload_hints, - **deprecated_kwargs, ) if warmup and not manager.warmed_up: # Warmup is long (~30 sec): best to signal the user it's happening than let them think the manager is stuck @@ -1027,8 +1068,6 @@ def generate_batch( progress_bar: If set to true, a progress bar will be displayed persistent_manager: whether to persist the manager after the generation is finished. Default is False. warmup: whether to pre-capture CUDA graphs before processing requests. Default is True. - **kwargs: Additional generation parameters. Only max_new_tokens is used, but other deprecated arguments - are extracted and passed to the continuous_batching_config object. Returns: `dict[str, GenerationOutput]`: a dictionary of request ids to GenerationOutput objects """ @@ -1041,21 +1080,6 @@ def generate_batch( logger.warning("Progress bar is disabled when logger level is less than DEBUG") progress_bar = False - # Extract deprecated arguments from regular kwargs (deprecated in v5.3). These args are now expected in the - # continuous_batching_config object. - deprecated_kwargs = {} - deprecated_keys = [ - "q_padding_interval_size", - "kv_padding_interval_size", - "allow_block_sharing", - "use_async_batching", - "max_cached_graphs", - "max_queue_size", - ] - for depr_key in deprecated_keys: - if depr_key in kwargs: - deprecated_kwargs[depr_key] = kwargs.pop(depr_key) - # Compute the total number of requests gen_cfg = self.generation_config if generation_config is None else generation_config num_return_sequences = gen_cfg.num_return_sequences if gen_cfg.num_return_sequences is not None else 1 @@ -1080,7 +1104,6 @@ def generate_batch( persistent_manager=persistent_manager, warmup=warmup, workload_hints=workload_hints, - **deprecated_kwargs, ) logging_cm = logging_redirect_tqdm([logger]) pbar_cm = tqdm( diff --git a/src/transformers/generation/continuous_batching/distributed.py b/src/transformers/generation/continuous_batching/distributed.py new file mode 100644 index 000000000000..99aa7ba5c322 --- /dev/null +++ b/src/transformers/generation/continuous_batching/distributed.py @@ -0,0 +1,131 @@ +# Copyright 2026 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import TypeVar + +import torch +import torch.distributed as dist +from torch.distributed.tensor.device_mesh import DeviceMesh + +from .requests import logger + + +T = TypeVar("T") + + +class DistributedHelper: + """A helper class to handle distributed-related operations. Notably, it does not crash when distributed is off.""" + + def __init__(self, device_mesh: DeviceMesh | None) -> None: + self.device_mesh = device_mesh + self.dist_on = dist.is_available() and dist.is_initialized() + + # These attributes depend on the global dist state + self.global_rank = dist.get_rank() if self.dist_on else 0 + self.world_size = dist.get_world_size() if self.dist_on else 1 + + # These attributes depend on the TP state + if self.dist_on and self.device_mesh is not None: + self.tp_size = self.device_mesh.size() + self.tp_group = self.device_mesh.get_group() + self.tp_root_global_rank = dist.get_global_rank(self.tp_group, 0) + self.tp_local_rank = self.device_mesh.get_local_rank() + # If TP is on, we create a dedicate CPU group + tp_ranks = dist.get_process_group_ranks(self.tp_group) + self.cpu_comm_group = dist.new_group(ranks=tp_ranks, backend="gloo") + else: + self.tp_size = 1 + self.tp_group = None + self.tp_root_global_rank = 0 + self.tp_local_rank = 0 + self.cpu_comm_group = None + + # The TP driver owns the request queue and scheduler decisions for its TP group. Single-process runs are + # their own driver. + self.is_tp_driver = self.infer_if_tp_driver() + + # These attributes depend on the DP state + self.dp_rank = self.global_rank // self.tp_size + self.dp_size = self.world_size // self.tp_size + + # Accumulator to CPU integer comm + self._cpu_int_acc = torch.tensor([0], dtype=torch.int64, device="cpu") + + def infer_if_tp_driver(self) -> bool: + return self.tp_local_rank == 0 + + def destroy_cpu_comm_group(self) -> None: + """Destroys the CPU comm group.""" + if self.cpu_comm_group is not None: + dist.destroy_process_group(self.cpu_comm_group) + self.cpu_comm_group = None + + def tp_broadcast_from_rank_0(self, value: torch.Tensor) -> torch.Tensor: + """Inside each TP group, broadcasts the given value from rank 0 to all other ranks.""" + if self.tp_size > 1: + dist.broadcast(value, src=self.tp_root_global_rank, async_op=False, group=self.tp_group) + return value + + def tp_broadcast_int(self, value: int) -> int: + """Inside each TP group, broadcasts an integer from rank 0 over the gloo CPU comm group.""" + if self.tp_size > 1: + self._cpu_int_acc[0] = value + dist.broadcast(self._cpu_int_acc, src=self.tp_root_global_rank, async_op=False, group=self.cpu_comm_group) + value = self._cpu_int_acc[0].item() + return value + + def tp_all_reduce_min(self, value: torch.Tensor) -> torch.Tensor: + """Inside each TP group, all-reduces a tensor with the MIN op. No-op when TP is off.""" + if self.tp_size > 1: + dist.all_reduce(value, op=dist.ReduceOp.MIN, group=self.tp_group) + return value + + def tp_broadcast_object(self, obj: T) -> T: + """Inside each TP group, broadcasts an arbitrary picklable Python object from TP-rank 0 to all other ranks. + Used to keep request ingress and cancellations consistent across TP workers without requiring all ranks to + receive the same external request stream. Uses a dedicated CPU (gloo) `cpu_comm_group` for broadcast.""" + if self.tp_size <= 1: + return obj + holder = [obj] if self.is_tp_driver else [None] + dist.broadcast_object_list( + holder, src=self.tp_root_global_rank, group=self.cpu_comm_group, device=torch.device("cpu") + ) + return holder[0] + + def maybe_warn_nccl_graph_mixing(self) -> None: + """Throws a warning if TP is on and NCCL's graph mixing support was supposed to be disabled but isn't. That can + happen if the distributed group is created before graph mixing is disabled. Typically, if the model is + initialized before the ContinuousBatchingConfig is created.""" + tp_on = self.tp_size > 1 + graph_mixing_not_disabled = os.environ.get("NCCL_GRAPH_MIXING_SUPPORT") != "0" + if tp_on and graph_mixing_not_disabled: + logger.warning( + "NCCL_GRAPH_MIXING_SUPPORT was not set to '0' before init_process_group: performance will be harmed. " + "Construct your `ContinuousBatchingConfig(...)` BEFORE calling `from_pretrained(tp_plan='auto')`, or " + "set NCCL_GRAPH_MIXING_SUPPORT=0 in the launch environment." + ) + + def set_tp_seed(self, seed: int | None, model_device: torch.device) -> None: + # Get an integer seed for the TP group + if seed is None: + tp_seed_tensor = torch.randint(0, 2**32 - 1, (1,), dtype=torch.int64, device=model_device) + else: + tp_seed_tensor = torch.tensor(seed, dtype=torch.int64, device=model_device) + # Broadcast the seed to all ranks from rank 0 and memoize it + tp_seed_tensor = self.tp_broadcast_from_rank_0(tp_seed_tensor) + tp_seed = tp_seed_tensor.item() + if self.global_rank == 0 and seed is None: + logger.info(f"Found no user-specified seed in the config. Setting the config seed to: {tp_seed}.") + # Set the seed while accounting for DP replicas + torch.manual_seed(tp_seed + self.dp_rank) diff --git a/src/transformers/generation/continuous_batching/utils.py b/src/transformers/generation/continuous_batching/utils.py index 038daa94812b..df00d1a80842 100644 --- a/src/transformers/generation/continuous_batching/utils.py +++ b/src/transformers/generation/continuous_batching/utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import queue from collections import OrderedDict from contextlib import contextmanager from dataclasses import dataclass @@ -208,6 +209,19 @@ def create_warmup_future_states( return future_states +def drain_queue(request_queue: queue.Queue) -> list[RequestState]: + """Drains a queue and returns a list of RequestStates.""" + new_states: list[RequestState] = [] + while not request_queue.empty(): + try: + state = request_queue.get_nowait() + if state is not None: + new_states.append(state) + except queue.Empty: + break + return new_states + + def get_cuda_pools(): # no type hint because it would make torch 2.4 crash """Returns a tuple of (mem_pool, graph_pool_id) for CUDA graphs. Since the MemPool object is only available in torch 2.5+, we only return a graph_pool_id for older versions.""" diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 1952b5e0fa26..36319008c9ad 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -15,6 +15,7 @@ import functools import gc import itertools +import os import unittest from typing import Any from unittest.mock import patch @@ -40,6 +41,7 @@ ) from transformers.generation.continuous_batching.cache_manager import FullAttentionCacheAllocator from transformers.generation.continuous_batching.continuous_api import OutputRouter +from transformers.generation.continuous_batching.distributed import DistributedHelper from transformers.generation.continuous_batching.input_outputs import build_attention_mask from transformers.generation.continuous_batching.offloading_manager import OffloadingManager from transformers.generation.continuous_batching.requests import GenerationOutput, RequestStatus @@ -50,6 +52,7 @@ require_kernels, require_torch_accelerator, require_torch_gpu, + require_torch_multi_accelerator, slow, torch_device, ) @@ -60,6 +63,8 @@ ) from transformers.utils.generic import is_flash_attention_requested +from ..test_tensor_parallel_mixin import _init_distributed + # Constants for tests _DEFAULT_USER_MESSAGES = [ @@ -351,7 +356,7 @@ def test_continuous_batching_will_allocation_be_successful( num_free_blocks: int, expected_result: bool, ) -> None: - """Test the will_allocation_be_successful method of PagedAttentionCache, overloading the elevant attributes of + """Test the will_allocation_be_successful method of PagedAttentionCache, overloading the relevant attributes of a dummy cache.""" if torch_device is None: # this check which should always pass and helps with type checking @@ -362,6 +367,8 @@ def test_continuous_batching_will_allocation_be_successful( config=AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="sdpa"), continuous_batching_config=ContinuousBatchingConfig(block_size=16, num_blocks=8, max_batch_tokens=8), device=torch_device, + tp_plan={}, + distributed_helper=DistributedHelper(device_mesh=None), ) # Overload cache parameters to match test scenario @@ -486,6 +493,79 @@ def test_output_router_deliver_to_handler(self): loop.call_soon_threadsafe.assert_called_once() self.assertTrue(router.output_queue.empty()) + def test_distributed_helper_no_dist(self) -> None: + """Test that DistributedHelper falls back to a single-rank, TP-driver setup when distributed is not on.""" + helper = DistributedHelper(device_mesh=None) + self.assertFalse(helper.dist_on) + self.assertEqual(helper.global_rank, 0) + self.assertEqual(helper.world_size, 1) + self.assertEqual(helper.tp_size, 1) + self.assertEqual(helper.tp_local_rank, 0) + self.assertEqual(helper.dp_rank, 0) + self.assertEqual(helper.dp_size, 1) + self.assertTrue(helper.is_tp_driver) + self.assertIsNone(helper.tp_group) + self.assertIsNone(helper.cpu_comm_group) + + # Tensor and object broadcasts should be no-ops without a TP group + tensor = torch.tensor([1.0, 2.0]) + self.assertTrue(torch.equal(helper.tp_broadcast_from_rank_0(tensor), tensor)) + obj = {"some_request": "payload"} + self.assertIs(helper.tp_broadcast_object(obj), obj) + + # All-reduce-min should be a no-op without a TP group + reduce_tensor = torch.tensor([7, 3], dtype=torch.int64) + self.assertIs(helper.tp_all_reduce_min(reduce_tensor), reduce_tensor) + self.assertTrue(torch.equal(reduce_tensor, torch.tensor([7, 3], dtype=torch.int64))) + + def test_distributed_helper_set_tp_seed_no_dist(self) -> None: + """Test that set_tp_seed sets a torch seed without distributed initialized, both with and without a user seed.""" + helper = DistributedHelper(device_mesh=None) + + # Explicit seed: torch RNG state must be reproducible across calls + helper.set_tp_seed(seed=42, model_device=torch.device("cpu")) + first = torch.randint(0, 2**31 - 1, (4,)) + helper.set_tp_seed(seed=42, model_device=torch.device("cpu")) + second = torch.randint(0, 2**31 - 1, (4,)) + self.assertTrue(torch.equal(first, second)) + + # No seed: should not raise and should still set a torch seed + helper.set_tp_seed(seed=None, model_device=torch.device("cpu")) + + def test_continuous_batching_config_disables_nccl_graph_mixing(self) -> None: + """Test that ContinuousBatchingConfig sets NCCL_GRAPH_MIXING_SUPPORT=0 only under a distributed launch + (WORLD_SIZE > 1) and respects the disable_nccl_graph_mixing flag.""" + original_nccl = os.environ.pop("NCCL_GRAPH_MIXING_SUPPORT", None) + original_ws = os.environ.pop("WORLD_SIZE", None) + try: + # Single-GPU launch (no WORLD_SIZE): env var is left untouched + ContinuousBatchingConfig() + self.assertNotIn("NCCL_GRAPH_MIXING_SUPPORT", os.environ) + + # Distributed launch (WORLD_SIZE > 1): env var is set to "0" + os.environ["WORLD_SIZE"] = "2" + ContinuousBatchingConfig() + self.assertEqual(os.environ.get("NCCL_GRAPH_MIXING_SUPPORT"), "0") + + # Explicitly disabled flag: env var is left untouched even under a distributed launch + os.environ.pop("NCCL_GRAPH_MIXING_SUPPORT", None) + ContinuousBatchingConfig(disable_nccl_graph_mixing=False) + self.assertNotIn("NCCL_GRAPH_MIXING_SUPPORT", os.environ) + + # setdefault semantics: a pre-existing value is preserved + os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "1" + ContinuousBatchingConfig() + self.assertEqual(os.environ.get("NCCL_GRAPH_MIXING_SUPPORT"), "1") + finally: + if original_nccl is None: + os.environ.pop("NCCL_GRAPH_MIXING_SUPPORT", None) + else: + os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = original_nccl + if original_ws is None: + os.environ.pop("WORLD_SIZE", None) + else: + os.environ["WORLD_SIZE"] = original_ws + @require_torch_accelerator class ContinuousBatchingWithAcceleratorTest(unittest.TestCase): @@ -598,23 +678,40 @@ def _test_continuous_batching_parity( [False, True], ["eager", "sdpa", "flash_attention_2"], [False, True], - [False, True], ) ) ) @slow - def test_continuous_batching_config_combinations( + def test_continuous_batching_config_combinations_no_compile( self, allow_block_sharing: bool, attn_implementation: str, use_cuda_graph: bool, - use_compile: bool, ) -> None: + # Compiling adds a lot of overhead, so it's better not to include here (2*3*2=12 tests because of cross product) model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" continuous_batching_config = ContinuousBatchingConfig( allow_block_sharing=allow_block_sharing, use_cuda_graph=use_cuda_graph, - use_default_compile_configs=use_compile, + use_default_compile_configs=False, + ) + self._test_continuous_batching_parity( + model_id=model_id, + continuous_batching_config=continuous_batching_config, + attn_implementation=attn_implementation, + ) + + @parameterized.expand([("eager", False), ("sdpa", False), ("sdpa", True), ("flash_attention_2", True)]) + @slow + def test_continuous_batching_config_combinations_with_compile( + self, + attn_implementation: str, + use_cuda_graph: bool, + ) -> None: + model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + continuous_batching_config = ContinuousBatchingConfig( + use_cuda_graph=use_cuda_graph, + use_default_compile_configs=True, ) self._test_continuous_batching_parity( model_id=model_id, @@ -627,7 +724,7 @@ def test_continuous_batching_config_combinations( @parameterized.expand( list( itertools.product( - ["TinyLlama/TinyLlama-1.1B-Chat-v1.0", "google/gemma-2-2b-it"], + ["google/gemma-2-2b-it"], [False, True], [False, True], ) @@ -1005,18 +1102,25 @@ def test_num_return_sequences(self, allow_block_sharing: bool) -> None: # Tests to check addtional features of CB do not change its results # # --------------------------------------------------------------------------------------------- # @parameterized.expand( - list( - itertools.product( - ["sdpa", "flash_attention_2", "flash_attention_3"], - [False, True], - [False, True], - ) - ) + [ + # SDPA: basic features or full features + ("sdpa", False, False), + ("sdpa", True, True), + # FA2: full coverage + ("flash_attention_2", False, False), + ("flash_attention_2", False, True), + ("flash_attention_2", True, False), + ("flash_attention_2", True, True), + # FA3: always turn on CUDA graphs + ("flash_attention_3", True, False), + ("flash_attention_3", True, True), + ] ) @slow def test_continuous_batching_async( self, attn_implementation: str, use_cuda_graph: bool, use_compile: bool ) -> None: + # Again, we try to not overly use_compile because it adds a lot of overhead model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" self._test_continuous_batching_parity( model_id=model_id, @@ -1106,11 +1210,11 @@ def test_per_request_logits_processors(self, use_cuda_graph: bool, use_async_bat use_async_batching=use_async_batching, per_request_processors=True, return_logprobs=True, + q_padding_interval_size=16, # allows for exact comparison between CB and regular generation ) manager = model.init_continuous_batching( generation_config=generation_config, continuous_batching_config=continuous_batching_config, - q_padding_interval_size=16, # allows for exact comparison between CB and regular generation ) # Trick to have temperature, top-k, top-p ... without randomness: diable sampling after manager creation @@ -1337,3 +1441,221 @@ def test_memory_prediction( f"CUDA delta ({actual_cuda}) too far from prediction ({predicted}), " f"allowed overhead = {max_cuda_overhead} ({num_allocations} allocs × 512B)", ) + + +# Worker functions for the TP continuous batching tests, spawned through `_init_distributed`. +def _tp_continuous_batching_worker( + rank: int, + model_id: str, + attn_implementation: str, + max_new_tokens: int, + do_sample: bool, + seed: int, + use_cuda_graph: bool, + use_async_batching: bool, +) -> None: + """Loads `model_id` with `tp_plan="auto"`, checks three TP-specific paths in the same process: (a) direct + broadcasts via `DistributedHelper`, (b) per-rank parity of CB-generated tokens via `dist.all_gather_object`, and + (c) reproducibility across two CB runs sharing the same seed. Rank 0 owns all the assertions; the other ranks + only need to participate in the collectives.""" + import torch + import torch.distributed as dist + + from transformers.generation.continuous_batching.distributed import DistributedHelper + + tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") + if not hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token"): + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation=attn_implementation, tp_plan="auto", dtype=torch.float32 + ).eval() + + # Direct broadcast tests: only rank 0's value should propagate to every TP rank + helper = DistributedHelper(device_mesh=model._device_mesh) + + received_obj = helper.tp_broadcast_object({"src_rank": rank}) + assert received_obj == {"src_rank": 0}, f"tp_broadcast_object: rank {rank} got {received_obj}" + + sent_tensor = torch.tensor([float(rank)], device=model.device) + helper.tp_broadcast_from_rank_0(sent_tensor) + assert sent_tensor.item() == 0.0, f"tp_broadcast_from_rank_0: rank {rank} got {sent_tensor.item()}" + + # CB runs: same seed twice, assert reproducibility AND cross-rank parity + user_messages = [ + "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?" + ] + chats = [[{"role": "user", "content": m}] for m in user_messages] + tokenized = [tokenizer.apply_chat_template(chat, add_generation_prompt=True) for chat in chats] + input_ids = [(x if isinstance(x, list) else x["input_ids"]) for x in tokenized] + + cb_config_kwargs = {"use_cuda_graph": use_cuda_graph, "use_async_batching": use_async_batching, "seed": seed} + gen_config = GenerationConfig(do_sample=do_sample, max_new_tokens=max_new_tokens) + first_outputs = model.generate_batch( + inputs=input_ids, + generation_config=gen_config, + continuous_batching_config=ContinuousBatchingConfig(**cb_config_kwargs), + ) + second_outputs = model.generate_batch( + inputs=input_ids, + generation_config=gen_config, + continuous_batching_config=ContinuousBatchingConfig(**cb_config_kwargs), + ) + + # Cross-rank parity: every TP rank must produce the same tokens, otherwise the seed broadcast / TP collectives are + # diverging silently. Gather the first run's tokens onto all ranks and let rank 0 compare. + local_tokens = [out.generated_tokens for out in first_outputs.values()] + gathered_tokens = [None] * helper.tp_size + dist.all_gather_object(gathered_tokens, local_tokens, group=helper.tp_group) + + if rank != 0: + return + + assert len(first_outputs) == len(input_ids), f"Expected {len(input_ids)} CB outputs, got {len(first_outputs)}" + for i, (_, output) in enumerate(first_outputs.items()): + assert len(output.generated_tokens) > 0, f"Request {i} got no generated tokens" + + for src_rank, src_tokens in enumerate(gathered_tokens): + if src_tokens != gathered_tokens[0]: + raise AssertionError( + f"TP continuous batching diverges across ranks: rank {src_rank} got {src_tokens}, rank 0 got " + f"{gathered_tokens[0]}" + ) + + second_tokens = [out.generated_tokens for out in second_outputs.values()] + if local_tokens != second_tokens: + raise AssertionError( + f"TP continuous batching is not reproducible across runs with the same seed\n" + f"First run : {local_tokens}\n" + f"Second run: {second_tokens}" + ) + + +def _tp_cancellation_worker( + rank: int, + model_id: str, + attn_implementation: str, + use_cuda_graph: bool = False, + use_async_batching: bool = False, +) -> None: + """Loads `model_id` with `tp_plan="auto"`, submits a long-running streaming request, and cancels it mid-flight. + The cancellation goes through the cancel-queue + `tp_broadcast_object` path: if the broadcast were broken, the + non-driver rank's scheduler would not learn about the cancellation and the test would hang or crash on the next + TP forward pass. Rank 0 owns the assertions.""" + import time + + import torch + + cb_config = ContinuousBatchingConfig(use_cuda_graph=use_cuda_graph, use_async_batching=use_async_batching) + + tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") + if not hasattr(tokenizer, "pad_token") and hasattr(tokenizer, "eos_token"): + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation=attn_implementation, tp_plan="auto", dtype=torch.float32 + ).eval() + + chat = [{"role": "user", "content": "Tell me a long story about a robot exploring the galaxy."}] + tokenized = tokenizer.apply_chat_template(chat, add_generation_prompt=True) + inputs = tokenized if isinstance(tokenized, list) else tokenized["input_ids"] + + max_new_tokens = 200 + cancel_after_n_chunks = 3 + + manager = model.init_continuous_batching(continuous_batching_config=cb_config) + manager.logit_processor.clear() + # Warm up synchronously so CUDA-graph capture doesn't eat the streaming-loop deadline below + manager.warmup() + manager.start() + try: + request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True) + chunks_seen = 0 + cancelled = False + deadline = time.time() + 60 + while time.time() < deadline: + chunk = manager.get_result(request_id=request_id, timeout=2.0) + if chunk is None: + # No new chunks for 2s after cancel — cancellation took effect on every rank + break + chunks_seen += 1 + if chunks_seen >= cancel_after_n_chunks and not cancelled: + manager.cancel_request(request_id) + cancelled = True + if rank == 0: + assert cancelled, "Test setup did not reach the cancel call" + assert chunks_seen < max_new_tokens, ( + f"Cancellation did not stop generation early: saw {chunks_seen} chunks " + f"for max_new_tokens={max_new_tokens}" + ) + finally: + manager.stop(block=True) + + +@require_torch_multi_accelerator +class ContinuousBatchingTensorParallelTest(unittest.TestCase): + """Integration tests for continuous batching with tensor parallelism. Each test spawns a TP-sized process group + via `_init_distributed` (see `tests/test_tensor_parallel_mixin.py`) with the NCCL backend.""" + + @property + def tp_size(self) -> int: + return min(torch.cuda.device_count(), 2) + + def _run_cb_worker(self, **worker_kwargs) -> None: + """Spawn `_tp_continuous_batching_worker` on `tp_size` NCCL processes with sensible defaults.""" + defaults = { + "model_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "attn_implementation": "sdpa", + "max_new_tokens": 20, + "do_sample": False, + "seed": 42, + "use_cuda_graph": False, + "use_async_batching": False, + } + defaults.update(worker_kwargs) + _init_distributed(tp=self.tp_size, backend="nccl")(_tp_continuous_batching_worker)(**defaults) + + @slow + def test_continuous_batching_tp_greedy(self) -> None: + """Test that continuous batching with `tp_plan="auto"` produces non-empty, reproducible greedy outputs and + that all TP ranks agree on the generated tokens.""" + self._run_cb_worker() + + @slow + def test_continuous_batching_tp_with_sampling(self) -> None: + """Test that continuous batching with TP and sampling is reproducible across runs with the same seed and that + all TP ranks agree on the sampled tokens — implicitly validating the seed broadcast from rank 0.""" + self._run_cb_worker(do_sample=True, seed=123) + + @slow + def test_continuous_batching_tp_with_cuda_graph(self) -> None: + """Test that continuous batching with TP and CUDA graphs is reproducible across runs and that all TP ranks + agree on the generated tokens — captured-graph collectives must stay in sync across ranks.""" + self._run_cb_worker(use_cuda_graph=True) + + @slow + def test_continuous_batching_tp_with_cuda_graph_and_async(self) -> None: + """Test that continuous batching with TP, CUDA graphs, and async batching is reproducible across runs and + that all TP ranks agree on the generated tokens — the toughest combination, exercising both captured-graph + collectives and the async producer/consumer split.""" + self._run_cb_worker(use_cuda_graph=True, use_async_batching=True) + + @slow + def test_continuous_batching_tp_cancellation(self) -> None: + """Test that `cancel_request` propagates across the TP group: the driver enqueues the cancellation, broadcasts + it to non-driver ranks via `tp_broadcast_object`, and generation stops well before `max_new_tokens`.""" + _init_distributed(tp=self.tp_size, backend="nccl")(_tp_cancellation_worker)( + model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + attn_implementation="sdpa", + ) + + @slow + def test_continuous_batching_tp_cancellation_realistic(self) -> None: + """Test that `cancel_request` propagates across the TP group: the driver enqueues the cancellation, broadcasts + it to non-driver ranks via `tp_broadcast_object`, and generation stops well before `max_new_tokens`.""" + _init_distributed(tp=self.tp_size, backend="nccl")(_tp_cancellation_worker)( + model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + attn_implementation="sdpa", + use_async_batching=True, + use_cuda_graph=True, + ) diff --git a/tests/test_tensor_parallel_mixin.py b/tests/test_tensor_parallel_mixin.py index 68780cf2f8ab..547bce7dacc4 100644 --- a/tests/test_tensor_parallel_mixin.py +++ b/tests/test_tensor_parallel_mixin.py @@ -69,7 +69,7 @@ def get_packed_grad_shard(grad, world_size, rank, dim): return grad.index_select(dim, torch.tensor(indices, device=grad.device)) -def _global_wrapper(rank, func, tp, port, func_args, func_kwargs): +def _global_wrapper(rank, func, tp, port, backend, func_args, func_kwargs): """Wrapper to set up distributed environment and run the test function.""" def setup_dist_env(rank, world_size, port): @@ -82,7 +82,7 @@ def setup_dist_env(rank, world_size, port): world_size = tp setup_dist_env(rank, world_size, port) - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) func(rank, *func_args, **func_kwargs) @@ -90,7 +90,7 @@ def setup_dist_env(rank, world_size, port): dist.destroy_process_group() -def _init_distributed(tp: int, max_retries: int = 5): +def _init_distributed(tp: int, max_retries: int = 5, backend: str = "gloo"): """Decorator to initialize distributed environment and spawn processes.""" def _init_distributed_inner(func): @@ -98,7 +98,7 @@ def wrapper(*args, **kwargs): world_size = tp for attempt in range(max_retries): port = _find_free_port() - spawn_args = (func, tp, port, args, kwargs) + spawn_args = (func, tp, port, backend, args, kwargs) try: mp.spawn(_global_wrapper, args=spawn_args, nprocs=world_size) return