Skip to content

Commit e04ce5e

Browse files
committed
Fixes
1 parent 85d075d commit e04ce5e

4 files changed

Lines changed: 52 additions & 25 deletions

File tree

src/transformers/generation/continuous_batching/cache.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
2525
from .initialization import resolve_max_memory_percent
2626
from .requests import RequestState, RequestStatus, get_device_and_memory_breakdown, logger
27+
from .utils import DistributedHelper
2728

2829

2930
def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
@@ -122,8 +123,8 @@ def __init__(
122123
config: PreTrainedConfig,
123124
continuous_batching_config: ContinuousBatchingConfig,
124125
device: torch.device | str,
126+
distributed_helper: DistributedHelper,
125127
dtype: torch.dtype = torch.float16,
126-
tp_size: int | None = None,
127128
) -> None:
128129
"""Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has
129130
only full attention layers.
@@ -132,8 +133,8 @@ def __init__(
132133
config: Model configuration
133134
continuous_batching_config: Continuous batching configuration containing cache parameters
134135
device: Device for the cache tensors
136+
distributed_helper: TP-aware helper. Used to dispatch attention heads and ensure coherent cache size
135137
dtype: Data type of the cache
136-
tp_size: Tensor parallelism size
137138
"""
138139
self.config = config
139140
self.dtype = dtype
@@ -165,7 +166,8 @@ def __init__(
165166

166167
# Account for TP: each KV head is dispatched to a different GPU, so the effective number of KV heads per GPU is
167168
# simply divided by the TP size (number of GPUs)
168-
if tp_size is not None and tp_size > 1:
169+
tp_size = distributed_helper.tp_size
170+
if tp_size > 1:
169171
if self.num_key_value_heads % tp_size != 0:
170172
raise ValueError(
171173
f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
@@ -214,6 +216,12 @@ def __init__(
214216
cache_dtype=self.dtype,
215217
)
216218

219+
# For TP, align num_blocks and max_batch_tokens to the minimal value across the TP group
220+
if tp_size > 1:
221+
sync = torch.tensor([num_blocks, max_batch_tokens], device=self.device, dtype=torch.int64)
222+
distributed_helper.tp_all_reduce_min(sync)
223+
num_blocks, max_batch_tokens = int(sync[0].item()), int(sync[1].item())
224+
217225
# Add the inferred attributes to the class
218226
self.num_blocks = num_blocks
219227
self.max_batch_tokens = max_batch_tokens

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import torch
2727
from torch import nn
28-
from torch.distributed.tensor.device_mesh import DeviceMesh
2928
from tqdm import tqdm
3029
from tqdm.contrib.logging import logging_redirect_tqdm
3130

@@ -149,7 +148,7 @@ def __init__(
149148
model_device: torch.device,
150149
model_dtype: torch.dtype,
151150
scheduler: Scheduler,
152-
device_mesh: DeviceMesh | None,
151+
distributed_helper: DistributedHelper,
153152
) -> None:
154153
"""Initialize the continuous batch processor.
155154
@@ -166,7 +165,7 @@ def __init__(
166165
model_device: Device for model inputs/outputs
167166
model_dtype: Data type for model inputs/outputs
168167
scheduler: The [`Scheduler`] to use
169-
device_mesh: The device mesh if there is one
168+
distributed_helper: The [`DistributedHelper`] to use
170169
"""
171170
self.cache = cache
172171
self.config = config
@@ -179,7 +178,7 @@ def __init__(
179178
self.model_device = model_device
180179
self.model_dtype = model_dtype
181180
self.scheduler = scheduler
182-
self.distributed_helper = DistributedHelper(device_mesh=device_mesh)
181+
self.distributed_helper = distributed_helper
183182

184183
# Generation-related attributes
185184
self.do_sample = getattr(generation_config, "do_sample", True)
@@ -268,7 +267,7 @@ def _get_new_requests(self) -> None:
268267
payload = (new_states, cancellations)
269268
# Otherwise, the payload is None
270269
else:
271-
payload = None
270+
payload = ([], [])
272271

273272
# Broadcast within the TP group. No-op when tp_size == 1, returns the driver's payload unchanged.
274273
payload = self.distributed_helper.tp_broadcast_object(payload)
@@ -521,11 +520,11 @@ def __init__(
521520
self._request_lock = threading.Lock()
522521

523522
# Infer if this process is the driver of its own TP group
524-
helper = DistributedHelper(device_mesh=getattr(self.model, "_device_mesh", None))
525-
self.is_tp_driver = helper.is_tp_driver
523+
self.distributed_helper = DistributedHelper(device_mesh=getattr(self.model, "_device_mesh", None))
524+
self.is_tp_driver = self.distributed_helper.is_tp_driver
526525
# If TP is on, check if NCCL graph mixing is disabled (helps with performance)
527526
if continuous_batching_config.disable_nccl_graph_mixing:
528-
helper.maybe_warn_nccl_graph_mixing()
527+
self.distributed_helper.maybe_warn_nccl_graph_mixing()
529528

530529
# Generation config related arguments
531530
num_return_sequences = getattr(generation_config, "num_return_sequences", None)
@@ -601,6 +600,7 @@ def stop(self, block: bool = True, timeout: float | None = None, keep_for_next_s
601600
# If the manager is not being kept for next session, we clear the batch processor
602601
if not keep_for_next_session:
603602
self.batch_processor = None
603+
self.distributed_helper.destroy_ingress_group()
604604
# Otherwise, we keep the batch processor and cache the manager as a model attribute
605605
else:
606606
logger.info("Continuous batching manager will be kept for next session.")
@@ -792,15 +792,13 @@ def _generation_step(self) -> None:
792792
self.batch_processor._generation_step(self.model)
793793

794794
def _create_batch_processor(self) -> ContinuousBatchProcessor:
795-
# Retrieve the device mesh if there is one
796-
device_mesh: DeviceMesh | None = getattr(self.model, "_device_mesh", None)
797795
# Create the PagedAttentionCache
798796
paged_attention_cache = PagedAttentionCache(
799-
self.model.config,
800-
self.continuous_batching_config,
801-
self.model.device,
802-
self.model.dtype,
803-
tp_size=DistributedHelper(device_mesh=device_mesh).tp_size, # consistent with the batch processor
797+
config=self.model.config,
798+
continuous_batching_config=self.continuous_batching_config,
799+
device=self.model.device,
800+
distributed_helper=self.distributed_helper,
801+
dtype=self.model.dtype,
804802
)
805803
self._use_prefix_sharing = paged_attention_cache.use_prefix_sharing # update the approximation
806804

@@ -829,7 +827,7 @@ def _create_batch_processor(self) -> ContinuousBatchProcessor:
829827
model_device=self.model.device,
830828
model_dtype=self.model.dtype,
831829
scheduler=scheduler(paged_attention_cache),
832-
device_mesh=device_mesh,
830+
distributed_helper=self.distributed_helper,
833831
)
834832
return batch_processor
835833

src/transformers/generation/continuous_batching/utils.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections import OrderedDict
1717
from dataclasses import dataclass
1818
from math import ceil, log2
19-
from typing import Any
19+
from typing import Any, TypeVar
2020

2121
import torch
2222
import torch.distributed as dist
@@ -27,6 +27,9 @@
2727
from .requests import FutureRequestState, RequestState, RequestStatus, logger
2828

2929

30+
T = TypeVar("T")
31+
32+
3033
class CudaGraphBuffer:
3134
"""A fixed-size dict for CUDA graphs with LRU eviction when full."""
3235

@@ -264,13 +267,25 @@ def __init__(self, device_mesh: DeviceMesh | None) -> None:
264267
self.dp_rank = self.global_rank // self.tp_size
265268
self.dp_size = self.world_size // self.tp_size
266269

270+
def destroy_ingress_group(self) -> None:
271+
"""Destroys the ingress group."""
272+
if self.ingress_group is not None:
273+
dist.destroy_process_group(self.ingress_group)
274+
self.ingress_group = None
275+
267276
def tp_broadcast_from_rank_0(self, value: torch.Tensor) -> torch.Tensor:
268277
"""Inside each TP group, broadcasts the given value from rank 0 to all other ranks."""
269278
if self.tp_size > 1:
270279
dist.broadcast(value, src=self.tp_root_global_rank, async_op=False, group=self.tp_group)
271280
return value
272281

273-
def tp_broadcast_object(self, obj):
282+
def tp_all_reduce_min(self, value: torch.Tensor) -> torch.Tensor:
283+
"""Inside each TP group, all-reduces a tensor with the MIN op. No-op when TP is off."""
284+
if self.tp_size > 1:
285+
dist.all_reduce(value, op=dist.ReduceOp.MIN, group=self.tp_group)
286+
return value
287+
288+
def tp_broadcast_object(self, obj: T) -> T:
274289
"""Inside each TP group, broadcasts an arbitrary picklable Python object from TP-rank 0 to all other ranks.
275290
Used to keep request ingress and cancellations consistent across TP workers without requiring all ranks to
276291
receive the same external request stream. Uses a dedicated CPU (gloo) `ingress_group` for broadcast."""
@@ -287,8 +302,8 @@ def maybe_warn_nccl_graph_mixing(self) -> None:
287302
happen if the distributed group is created before graph mixing is disabled. Typically, if the model is
288303
initialized before the ContinousBatchingConfig is created."""
289304
tp_on = self.tp_size > 1
290-
graph_mixing_supported = os.environ.get("NCCL_GRAPH_MIXING_SUPPORT") != "0"
291-
if tp_on or graph_mixing_supported:
305+
graph_mixing_not_disabled = os.environ.get("NCCL_GRAPH_MIXING_SUPPORT") != "0"
306+
if tp_on and graph_mixing_not_disabled:
292307
logger.warning(
293308
"NCCL_GRAPH_MIXING_SUPPORT was not set to '0' before init_process_group: performance will be harmed. "
294309
"Construct your `ContinuousBatchingConfig(...)` BEFORE calling `from_pretrained(tp_plan='auto')`, or "
@@ -304,7 +319,7 @@ def set_tp_seed(self, seed: int | None, model_device: torch.device) -> None:
304319
# Broadcast the seed to all ranks from rank 0 and memoize it
305320
tp_seed_tensor = self.tp_broadcast_from_rank_0(tp_seed_tensor)
306321
tp_seed = tp_seed_tensor.item()
307-
if self.global_rank == 0:
308-
logger.warning(f"Found no user-specified seed in the config. Setting the config seed to: {tp_seed}.")
322+
if self.global_rank == 0 and seed is None:
323+
logger.info(f"Found no user-specified seed in the config. Setting the config seed to: {tp_seed}.")
309324
# Set the seed while accounting for DP replicas
310325
torch.manual_seed(tp_seed + self.dp_rank)

tests/generation/test_continuous_batching.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def test_continuous_batching_will_allocation_be_successful(
367367
config=AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="sdpa"),
368368
continuous_batching_config=ContinuousBatchingConfig(block_size=16, num_blocks=8, max_batch_tokens=8),
369369
device=torch_device,
370+
distributed_helper=DistributedHelper(device_mesh=None),
370371
)
371372

372373
# Overload cache parameters to match test scenario
@@ -511,6 +512,11 @@ def test_distributed_helper_no_dist(self) -> None:
511512
obj = {"some_request": "payload"}
512513
self.assertIs(helper.tp_broadcast_object(obj), obj)
513514

515+
# All-reduce-min should be a no-op without a TP group
516+
reduce_tensor = torch.tensor([7, 3], dtype=torch.int64)
517+
self.assertIs(helper.tp_all_reduce_min(reduce_tensor), reduce_tensor)
518+
self.assertTrue(torch.equal(reduce_tensor, torch.tensor([7, 3], dtype=torch.int64)))
519+
514520
def test_distributed_helper_set_tp_seed_no_dist(self) -> None:
515521
"""Test that set_tp_seed sets a torch seed without distributed initialized, both with and without a user seed."""
516522
helper = DistributedHelper(device_mesh=None)

0 commit comments

Comments
 (0)