Skip to content

Commit e4f5f2f

Browse files
remi-orkhushali9
authored and committed
[CB] [Major] Add tensor parallelism (huggingface#45821)
* TP heads and DP / TP seeds * Reproducible hash * Add the notion of TP drivers * Fix NCCL device * Temporary fix for multiple streams * Better handling of NCCL graph mixing * Fix cfg * nit * Move the seed setting * Reworked overall to have accuracy scoring * Adding tests 1/n * Added tests * Style * Fixes * CC review * Nits * Renames * Small fixes * Move distributed stuff to a distributed file * Docstring * Final fixes * Review compliance * Review compliance 2 * Rebase fix * Style * Less redundant testing suite * Fix TP plan * Fix stopping condition * Nits
1 parent 10f6112 commit e4f5f2f

9 files changed

Lines changed: 807 additions & 183 deletions

File tree

benchmark_v2/benchmark_scripts/continuous_batching_overall.py

Lines changed: 178 additions & 58 deletions
Large diffs are not rendered by default.

src/transformers/generation/configuration_utils.py

Lines changed: 18 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,6 +1653,8 @@ class ContinuousBatchingConfig:
16531653
Scheduler type to use.
16541654
return_logprobs (`bool`, *optional*, defaults to `False`):
16551655
Whether to return log probabilities along with the generated tokens.
1656+
seed (`int | None`, *optional*):
1657+
An optional seed for generation. If not specified, the internal seed will be set to a random value.
16561658
cpu_offload_space (`float`, *optional*, defaults to 0.0):
16571659
CPU swap space in GiB for KV cache offloading. A pre-allocated pinned CPU buffer of this size is
16581660
created at initialization. When the GPU cache is full, evicted requests' KV caches are copied here
@@ -1666,6 +1668,8 @@ class ContinuousBatchingConfig:
16661668
Enable per-request logits processor parameters. Default is False.
16671669
drop_unsupported_processors (`bool`, *optional*, defaults to `True`):
16681670
Remove unsupported logits processors instead of erroring. Default is True.
1671+
disable_nccl_graph_mixing (`bool`, *optional*, defaults to `True`):
1672+
Disable NCCL's safety net for parallel graph-captured comms. Such parallel comms never happen in CB, and disabling the safety net gives TP a perf boost.
16691673
"""
16701674

16711675
# Size of each KV cache block
@@ -1719,6 +1723,9 @@ class ContinuousBatchingConfig:
17191723
# probabilities will be returned along with the generated tokens in the generation output.
17201724
return_logprobs: bool = False
17211725

1726+
# An optional seed for generation. If not specified, the internal seed will be set to a random value.
1727+
seed: int | None = None
1728+
17221729
# CPU swap space in GiB for KV cache offloading. When the GPU cache is full and a request must be evicted, its KV
17231730
# cache is copied to this pre-allocated pinned CPU buffer instead of being discarded. Defaults to 0.0 GiB. You can
17241731
# also set this to None to dimension the pool using only the safety threshold, but this will error out if psutil is
@@ -1739,44 +1746,18 @@ class ContinuousBatchingConfig:
17391746
# are kept but warnings are logged for unsupported/unknown ones.
17401747
drop_unsupported_processors: bool = True
17411748

1742-
def account_for_cb_deprecated_arguments(
1743-
self,
1744-
max_queue_size: int = 0,
1745-
q_padding_interval_size: int = 0,
1746-
kv_padding_interval_size: int = 0,
1747-
allow_block_sharing: bool = True,
1748-
use_async_batching: bool | None = None,
1749-
max_cached_graphs: int = 0,
1750-
) -> None:
1751-
"""Some arguments given to `generate_batch`, `init_continuous_batching` or `continuous_batching_context_manager`
1752-
are now deprecated and are expected inside the continuous batching config. This method checks if any were
1753-
passed and accounts for them in the continuous batching config. It raises a deprecation warning if any were
1754-
passed.
1755-
"""
1756-
kwargs_to_warn = []
1757-
if max_queue_size > 0:
1758-
kwargs_to_warn.append("max_queue_size")
1759-
self.max_queue_size = max_queue_size
1760-
if q_padding_interval_size > 0:
1761-
kwargs_to_warn.append("q_padding_interval_size")
1762-
self.q_padding_interval_size = q_padding_interval_size
1763-
if kv_padding_interval_size > 0:
1764-
kwargs_to_warn.append("kv_padding_interval_size")
1765-
self.kv_padding_interval_size = kv_padding_interval_size
1766-
if not allow_block_sharing: # config default is True, so False means the user explicitly set it to False
1767-
kwargs_to_warn.append("allow_block_sharing")
1768-
self.allow_block_sharing = allow_block_sharing
1769-
if use_async_batching is not None:
1770-
kwargs_to_warn.append("use_async_batching")
1771-
self.use_async_batching = use_async_batching
1772-
if max_cached_graphs > 0:
1773-
kwargs_to_warn.append("max_cached_graphs")
1774-
self.max_cached_graphs = max_cached_graphs
1775-
if kwargs_to_warn:
1749+
# Disable NCCL's safety net for parallel graph-captured communications. This means it is no longer safe to replay a
1750+
# CUDA graph with NCCL communication at the same time as (1) another CUDA graph with captured comms or (2) an eager comm.
1751+
# This is turned on by default because the above never happens in CB and this gives a nice perf boost.
1752+
disable_nccl_graph_mixing: bool = True
1753+
1754+
def __post_init__(self):
1755+
# Only turn off graph mixing support if TP is on
1756+
if self.disable_nccl_graph_mixing and int(os.environ.get("WORLD_SIZE", "1")) > 1:
17761757
logger.warning(
1777-
"The following arguments were provided to a continuous batching entry point instead of being passed "
1778-
"through the continuous_batching_config: " + ", ".join(kwargs_to_warn)
1758+
"Setting NCCL_GRAPH_MIXING_SUPPORT = 0 because disable_nccl_graph_mixing is True and WORLD_SIZE > 1."
17791759
)
1760+
os.environ.setdefault("NCCL_GRAPH_MIXING_SUPPORT", "0")
17801761

17811762
@property
17821763
def cuda_graph_booleans(self) -> tuple[bool, bool]:
@@ -1789,5 +1770,5 @@ def cuda_graph_booleans(self) -> tuple[bool, bool]:
17891770

17901771
@property
17911772
def fallback_max_blocks_per_request(self) -> int:
1792-
"""Returns the max blocks per request."""
1773+
"""Fallback if no user-hint is given and decode path is available."""
17931774
return 32

src/transformers/generation/continuous_batching/cache.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ...utils.generic import is_flash_attention_requested
2323
from ...utils.metrics import attach_tracer, traced
2424
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
25+
from .distributed import DistributedHelper
2526
from .initialization import resolve_max_memory_percent
2627
from .requests import RequestState, RequestStatus, get_device_and_memory_breakdown, logger
2728

@@ -122,8 +123,9 @@ def __init__(
122123
config: PreTrainedConfig,
123124
continuous_batching_config: ContinuousBatchingConfig,
124125
device: torch.device | str,
126+
distributed_helper: DistributedHelper,
127+
tp_plan: dict[str, Any],
125128
dtype: torch.dtype = torch.float16,
126-
tp_size: int | None = None,
127129
) -> None:
128130
"""Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has
129131
only full attention layers.
@@ -132,8 +134,9 @@ def __init__(
132134
config: Model configuration
133135
continuous_batching_config: Continuous batching configuration containing cache parameters
134136
device: Device for the cache tensors
137+
distributed_helper: TP-aware helper. Used to dispatch attention heads and ensure coherent cache size
138+
tp_plan: Tensor parallelism plan
135139
dtype: Data type of the cache
136-
tp_size: Tensor parallelism size
137140
"""
138141
self.config = config
139142
self.dtype = dtype
@@ -163,14 +166,23 @@ def __init__(
163166
self.layer_index_to_group_indices[layer] = (i, j)
164167
self.sliding_windows[layer] = sliding_window
165168

166-
# Handle TP (or dont)
167-
if tp_size is not None and tp_size > 1:
169+
# Check if the KV heads are part of the TP plan. If they are not, the cache does not need to plan for TP.
170+
# TODO: this is fragile. If your model fails to TP properly because of this, please open an issue.
171+
kv_is_tp = True
172+
for key in ["layers.*.self_attn.k_proj", "layers.*.self_attn.v_proj"]:
173+
if not (key in tp_plan or "model." + key in tp_plan):
174+
kv_is_tp = False
175+
break
176+
177+
# If the KV heads are TP'ed, each KV head is dispatched to a different GPU, so the effective number of KV heads
178+
# per GPU is simply divided by the TP size
179+
tp_size = distributed_helper.tp_size
180+
if tp_size > 1 and kv_is_tp:
168181
if self.num_key_value_heads % tp_size != 0:
169182
raise ValueError(
170183
f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
171184
)
172-
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
173-
# self.num_key_value_heads //= tp_size # TODO: why is this commented out?
185+
self.num_key_value_heads //= tp_size
174186

175187
# Infer number of blocks and max batch tokens
176188
page_size = self.head_dim * self.num_key_value_heads
@@ -214,6 +226,12 @@ def __init__(
214226
cache_dtype=self.dtype,
215227
)
216228

229+
# For TP, align num_blocks and max_batch_tokens to the minimal value across the TP group
230+
if tp_size > 1:
231+
sync = torch.tensor([num_blocks, max_batch_tokens], device=self.device, dtype=torch.int64)
232+
distributed_helper.tp_all_reduce_min(sync)
233+
num_blocks, max_batch_tokens = int(sync[0].item()), int(sync[1].item())
234+
217235
# Add the inferred attributes to the class
218236
self.num_blocks = num_blocks
219237
self.max_batch_tokens = max_batch_tokens
@@ -270,7 +288,7 @@ def __init__(
270288

271289
# We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed
272290
self.use_prefix_sharing = self.allow_block_sharing and group_types == ["full_attention"]
273-
self._block_manager = BlockManager(num_blocks, self.block_size)
291+
self._block_manager = BlockManager(num_blocks, self.block_size, tp_on=tp_size > 1)
274292
self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests
275293

276294
# For block table support, we lazy init the name of the block table key

src/transformers/generation/continuous_batching/cache_manager.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import hashlib
1415
from abc import ABC, abstractmethod
16+
from array import array
1517
from collections import deque
1618
from collections.abc import Iterator
1719
from math import ceil
@@ -73,10 +75,11 @@ class BlockManager:
7375
it is in use.
7476
"""
7577

76-
def __init__(self, num_blocks: int, block_size: int) -> None:
78+
def __init__(self, num_blocks: int, block_size: int, tp_on: bool) -> None:
7779
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size)."""
7880
self.num_blocks = num_blocks
7981
self.block_size = block_size
82+
self.tp_on = tp_on
8083
self._uninit_block_ids = deque(range(num_blocks))
8184
self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
8285
self._hash_to_id: dict[int, int] = {}
@@ -276,7 +279,19 @@ def mark_shareable_blocks_as_complete(
276279
def compute_hash(self, parent_hash: int | None, tokens: list[int], group_id: int) -> int:
277280
"""Computes the hash of a block identified by the (tokens) it contains, its (parent_hash) and the layer
278281
(group_id) it belongs to. If the block has no parent, the parent hash is None."""
279-
return hash((parent_hash, tuple(tokens), group_id))
282+
# If TP is on, we cannot use python `hash` because it depends on the process (it's per-process salted)
283+
# TODO: figure out if this is really a problem. Even if hashes diverge per-process, does that break anything?
284+
if self.tp_on:
285+
h = hashlib.blake2b(digest_size=8)
286+
if parent_hash is not None:
287+
h.update(parent_hash.to_bytes(8, "little", signed=False))
288+
h.update(array("i", tokens).tobytes())
289+
h.update(group_id.to_bytes(4, "little", signed=False))
290+
hash_ = int.from_bytes(h.digest(), "little", signed=False)
291+
# Otherwise, use `hash`
292+
else:
293+
hash_ = hash((parent_hash, tuple(tokens), group_id))
294+
return hash_
280295

281296

282297
class CacheAllocator(ABC):

0 commit comments

Comments
 (0)