236 changes: 178 additions & 58 deletions benchmark_v2/benchmark_scripts/continuous_batching_overall.py

Large diffs are not rendered by default.

55 changes: 18 additions & 37 deletions src/transformers/generation/configuration_utils.py
@@ -1653,6 +1653,8 @@ class ContinuousBatchingConfig:
Scheduler type to use.
return_logprobs (`bool`, *optional*, defaults to `False`):
Whether to return log probabilities along with the generated tokens.
seed (`int | None`, *optional*):
An optional seed for generation. If not specified, the internal seed will be set to a random value.
cpu_offload_space (`float`, *optional*, defaults to 0.0):
CPU swap space in GiB for KV cache offloading. A pre-allocated pinned CPU buffer of this size is
created at initialization. When the GPU cache is full, evicted requests' KV caches are copied here
@@ -1666,6 +1668,8 @@ class ContinuousBatchingConfig:
Enable per-request logits processor parameters. Default is False.
drop_unsupported_processors (`bool`, *optional*, defaults to `True`):
Remove unsupported logits processors instead of erroring. Default is True.
disable_nccl_graph_mixing (`bool`, *optional*, defaults to `True`):
Disable NCCL's safety net for mixing graph-captured and concurrent communications. The unsafe pattern it guards against never occurs in continuous batching, so disabling it gives tensor parallelism a performance boost.
"""

# Size of each KV cache block
@@ -1719,6 +1723,9 @@ class ContinuousBatchingConfig:
# probabilities will be returned along with the generated tokens in the generation output.
return_logprobs: bool = False

# An optional seed for generation. If not specified, the internal seed will be set to a random value.
seed: int | None = None

# CPU swap space in GiB for KV cache offloading. When the GPU cache is full and a request must be evicted, its KV
# cache is copied to this pre-allocated pinned CPU buffer instead of being discarded. Defaults to 0.0 GiB. You can
# also set this to None to dimension the pool using only the safety threshold, but this will error out if psutil is
@@ -1739,44 +1746,18 @@ class ContinuousBatchingConfig:
# are kept but warnings are logged for unsupported/unknown ones.
drop_unsupported_processors: bool = True

def account_for_cb_deprecated_arguments(
self,
max_queue_size: int = 0,
q_padding_interval_size: int = 0,
kv_padding_interval_size: int = 0,
allow_block_sharing: bool = True,
use_async_batching: bool | None = None,
max_cached_graphs: int = 0,
) -> None:
"""Some arguments given to `generate_batch`, `init_continuous_batching` or `continuous_batching_context_manager`
are now deprecated and are expected inside the continuous batching config. This method checks if any were
passed and accounts for them in the continuous batching config. It raises a deprecation warning if any were
passed.
"""
kwargs_to_warn = []
if max_queue_size > 0:
kwargs_to_warn.append("max_queue_size")
self.max_queue_size = max_queue_size
if q_padding_interval_size > 0:
kwargs_to_warn.append("q_padding_interval_size")
self.q_padding_interval_size = q_padding_interval_size
if kv_padding_interval_size > 0:
kwargs_to_warn.append("kv_padding_interval_size")
self.kv_padding_interval_size = kv_padding_interval_size
if not allow_block_sharing: # config default is True, so False means the user explicitly set it to False
kwargs_to_warn.append("allow_block_sharing")
self.allow_block_sharing = allow_block_sharing
if use_async_batching is not None:
kwargs_to_warn.append("use_async_batching")
self.use_async_batching = use_async_batching
if max_cached_graphs > 0:
kwargs_to_warn.append("max_cached_graphs")
self.max_cached_graphs = max_cached_graphs
if kwargs_to_warn:
# Disable NCCL's safety net for parallel graph-captured communications. This means it is no longer safe to replay a
# CUDA graph with NCCL communication at the same time as 1. another CUDA graph with captured comms 2. an eager comm.
# This is turned on by default because the above never happens in CB and this gives a nice perf boost.
disable_nccl_graph_mixing: bool = True

def __post_init__(self):
# Only turn off graph mixing support if TP is on
if self.disable_nccl_graph_mixing and int(os.environ.get("WORLD_SIZE", "1")) > 1:
logger.warning(
"The following arguments were provided to a continuous batching entry point instead of being passed "
"through the continuous_batching_config: " + ", ".join(kwargs_to_warn)
"Setting NCCL_GRAPH_MIXING_SUPPORT = 0 because disable_nccl_graph_mixing is True and WORLD_SIZE > 1."
)
os.environ.setdefault("NCCL_GRAPH_MIXING_SUPPORT", "0")

@property
def cuda_graph_booleans(self) -> tuple[bool, bool]:
@@ -1789,5 +1770,5 @@ def cuda_graph_booleans(self) -> tuple[bool, bool]:

@property
def fallback_max_blocks_per_request(self) -> int:
"""Returns the max blocks per request."""
"""Fallback if no user-hint is given and decode path is available."""
return 32
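
For reference, a minimal usage sketch of the new config fields, assuming `ContinuousBatchingConfig` is constructed directly with keyword arguments from the module shown in this diff (the surrounding generation call is omitted):

```python
import os

from transformers.generation.configuration_utils import ContinuousBatchingConfig

# Hypothetical configuration: fix the generation seed and keep the (default)
# NCCL graph-mixing opt-out. The env var is only touched when WORLD_SIZE > 1.
cb_config = ContinuousBatchingConfig(
    return_logprobs=True,
    seed=42,                         # reproducible sampling instead of a random internal seed
    disable_nccl_graph_mixing=True,  # default; __post_init__ sets NCCL_GRAPH_MIXING_SUPPORT=0 under TP
)

# Under tensor parallelism, __post_init__ opts out of NCCL's graph-mixing safety net,
# since continuous batching never replays a graph-captured comm concurrently with
# another captured graph or an eager comm.
print(os.environ.get("NCCL_GRAPH_MIXING_SUPPORT"))  # "0" if WORLD_SIZE > 1, otherwise unset
```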
32 changes: 25 additions & 7 deletions src/transformers/generation/continuous_batching/cache.py
@@ -22,6 +22,7 @@
from ...utils.generic import is_flash_attention_requested
from ...utils.metrics import attach_tracer, traced
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .distributed import DistributedHelper
from .initialization import resolve_max_memory_percent
from .requests import RequestState, RequestStatus, get_device_and_memory_breakdown, logger

@@ -122,8 +123,9 @@ def __init__(
config: PreTrainedConfig,
continuous_batching_config: ContinuousBatchingConfig,
device: torch.device | str,
distributed_helper: DistributedHelper,
tp_plan: dict[str, Any],
dtype: torch.dtype = torch.float16,
tp_size: int | None = None,
) -> None:
"""Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has
only full attention layers.
@@ -132,8 +134,9 @@ def __init__(
config: Model configuration
continuous_batching_config: Continuous batching configuration containing cache parameters
device: Device for the cache tensors
distributed_helper: TP-aware helper, used to dispatch attention heads and keep the cache size consistent across ranks
tp_plan: Tensor parallelism plan
dtype: Data type of the cache
tp_size: Tensor parallelism size
"""
self.config = config
self.dtype = dtype
@@ -163,14 +166,23 @@ def __init__(
self.layer_index_to_group_indices[layer] = (i, j)
self.sliding_windows[layer] = sliding_window

# Handle TP (or don't)
if tp_size is not None and tp_size > 1:
# Check if the KV heads are part of the TP plan. If they are not, the cache does not need to plan for TP.
# TODO: this is fragile. If your model fails to TP properly because of this, please open an issue.
kv_is_tp = True
for key in ["layers.*.self_attn.k_proj", "layers.*.self_attn.v_proj"]:
if not (key in tp_plan or "model." + key in tp_plan):
kv_is_tp = False
break

# If the KV heads are TP'ed, each KV head is dispatched to a different GPU, so the effective number of KV heads
# per GPU is simply divided by the TP size
tp_size = distributed_helper.tp_size
if tp_size > 1 and kv_is_tp:
if self.num_key_value_heads % tp_size != 0:
raise ValueError(
f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
# self.num_key_value_heads //= tp_size # TODO: why is this commented out?
self.num_key_value_heads //= tp_size

# Infer number of blocks and max batch tokens
page_size = self.head_dim * self.num_key_value_heads
@@ -214,6 +226,12 @@ def __init__(
cache_dtype=self.dtype,
)

# For TP, align num_blocks and max_batch_tokens to the minimal value across the TP group
if tp_size > 1:
sync = torch.tensor([num_blocks, max_batch_tokens], device=self.device, dtype=torch.int64)
distributed_helper.tp_all_reduce_min(sync)
num_blocks, max_batch_tokens = int(sync[0].item()), int(sync[1].item())

# Add the inferred attributes to the class
self.num_blocks = num_blocks
self.max_batch_tokens = max_batch_tokens
@@ -270,7 +288,7 @@ def __init__(

# We only use prefix sharing if the whole model has only full attention layers and block sharing is allowed
self.use_prefix_sharing = self.allow_block_sharing and group_types == ["full_attention"]
self._block_manager = BlockManager(num_blocks, self.block_size)
self._block_manager = BlockManager(num_blocks, self.block_size, tp_on=tp_size > 1)
self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests

# For block table support, we lazy init the name of the block table key
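
A standalone sketch of the TP-aware sizing above may help: when the k/v projections are sharded by the TP plan, each rank keeps `num_key_value_heads // tp_size` heads, and the inferred `num_blocks` / `max_batch_tokens` are aligned to the minimum across the TP group. It assumes `DistributedHelper.tp_all_reduce_min` wraps a standard `all_reduce` with `ReduceOp.MIN` over an initialized TP process group; that helper's body is not part of this diff.

```python
import torch
import torch.distributed as dist


def tp_adjusted_cache_dims(
    num_key_value_heads: int,
    tp_size: int,
    kv_is_tp: bool,
    num_blocks: int,
    max_batch_tokens: int,
    device: torch.device,
) -> tuple[int, int, int]:
    # When the k/v projections are sharded by the TP plan, each rank only holds
    # num_key_value_heads // tp_size heads, so the per-rank cache shrinks accordingly.
    if tp_size > 1 and kv_is_tp:
        if num_key_value_heads % tp_size != 0:
            raise ValueError(
                f"Number of key value heads {num_key_value_heads} must be divisible "
                f"by tensor parallel size {tp_size}."
            )
        num_key_value_heads //= tp_size

    # Ranks may infer slightly different budgets (e.g. from free-memory probes), so
    # align num_blocks and max_batch_tokens to the minimum across the TP group.
    if tp_size > 1:
        sync = torch.tensor([num_blocks, max_batch_tokens], device=device, dtype=torch.int64)
        dist.all_reduce(sync, op=dist.ReduceOp.MIN)
        num_blocks, max_batch_tokens = int(sync[0].item()), int(sync[1].item())

    return num_key_value_heads, num_blocks, max_batch_tokens
```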
19 changes: 17 additions & 2 deletions src/transformers/generation/continuous_batching/cache_manager.py
@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
from abc import ABC, abstractmethod
from array import array
from collections import deque
from collections.abc import Iterator
from math import ceil
@@ -73,10 +75,11 @@ class BlockManager:
it is in use.
"""

def __init__(self, num_blocks: int, block_size: int) -> None:
def __init__(self, num_blocks: int, block_size: int, tp_on: bool) -> None:
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size)."""
self.num_blocks = num_blocks
self.block_size = block_size
self.tp_on = tp_on
self._uninit_block_ids = deque(range(num_blocks))
self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
self._hash_to_id: dict[int, int] = {}
@@ -276,7 +279,19 @@ def mark_shareable_blocks_as_complete(
def compute_hash(self, parent_hash: int | None, tokens: list[int], group_id: int) -> int:
"""Computes the hash of a block identified by the (tokens) it contains, its (parent_hash) and the layer
(group_id) it belongs to. If the block has no parent, the parent hash is None."""
return hash((parent_hash, tuple(tokens), group_id))
# If TP is on, we cannot use python `hash` because it depends on the process (it's per-process salted)
# TODO: figure out if this is really a problem. Even if hashes diverge per-process, does that break anything?
if self.tp_on:
h = hashlib.blake2b(digest_size=8)
if parent_hash is not None:
h.update(parent_hash.to_bytes(8, "little", signed=False))
h.update(array("i", tokens).tobytes())
h.update(group_id.to_bytes(4, "little", signed=False))
hash_ = int.from_bytes(h.digest(), "little", signed=False)
# Otherwise, use `hash`
else:
hash_ = hash((parent_hash, tuple(tokens), group_id))
return hash_


class CacheAllocator(ABC):
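
A standalone sketch of the hash change, under the assumption stated in the comment above (Python's built-in `hash` is process-dependent, so two TP ranks could disagree on block hashes and prefix matching would silently fail); `blake2b` with an 8-byte digest gives a process-independent 64-bit value:

```python
import hashlib
from array import array


def stable_block_hash(parent_hash: int | None, tokens: list[int], group_id: int) -> int:
    # Mirror of BlockManager.compute_hash in the TP branch: a deterministic
    # 64-bit hash over (parent_hash, tokens, group_id), identical on every rank.
    h = hashlib.blake2b(digest_size=8)
    if parent_hash is not None:
        h.update(parent_hash.to_bytes(8, "little", signed=False))
    h.update(array("i", tokens).tobytes())
    h.update(group_id.to_bytes(4, "little", signed=False))
    return int.from_bytes(h.digest(), "little", signed=False)


# The chained usage mirrors prefix sharing: a child block's hash folds in its parent's.
root = stable_block_hash(None, [101, 2023, 2003], group_id=0)
child = stable_block_hash(root, [1037, 3231, 102], group_id=0)
print(hex(root), hex(child))  # identical values on every rank/process
```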