Skip to content

Commit 2b230ea

Browse files
authored
[hybrid] simpler algorithm to find kernel_block_size (#1581)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 4b1e501 commit 2b230ea

3 files changed

Lines changed: 139 additions & 80 deletions

File tree

aphrodite/v1/worker/gpu_model_runner.py

Lines changed: 83 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -3712,6 +3712,7 @@ def get_attn_backends_for_group(
37123712

37133713
def create_attn_groups(
37143714
attn_backends_map: dict[AttentionGroupKey, list[str]],
3715+
kv_cache_group_id: int,
37153716
) -> list[AttentionGroup]:
37163717
attn_groups: list[AttentionGroup] = []
37173718
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
@@ -3721,6 +3722,7 @@ def create_attn_groups(
37213722
kv_cache_spec,
37223723
self.aphrodite_config,
37233724
self.device,
3725+
kv_cache_group_id,
37243726
num_metadata_builders=1 if not self.parallel_config.enable_dbo else 2,
37253727
)
37263728

@@ -3737,8 +3739,8 @@ def create_attn_groups(
37373739
# Resolve cudagraph_mode before actually initialize metadata_builders
37383740
self._check_and_update_cudagraph_mode(attention_backend_set)
37393741

3740-
for attn_backends_map in attention_backend_maps:
3741-
self.attn_groups.append(create_attn_groups(attn_backends_map))
3742+
for i, attn_backend_map in enumerate(attention_backend_maps):
3743+
self.attn_groups.append(create_attn_groups(attn_backend_map, i))
37423744

37433745
# Calculate reorder batch threshold (if needed)
37443746
self.calculate_reorder_batch_threshold()
@@ -3854,97 +3856,89 @@ def calculate_reorder_batch_threshold(self) -> None:
38543856
else:
38553857
self.reorder_batch_threshold = None
38563858

3857-
def _find_compatible_block_sizes(
3858-
self,
3859-
kv_manager_block_size: int,
3860-
backend_cls: type[AttentionBackend],
3861-
return_all: bool = False,
3862-
) -> list[int]:
3863-
"""
3864-
Find compatible block sizes for a backend.
3865-
3866-
Args:
3867-
kv_manager_block_size: Physical block size of KV cache
3868-
backend_cls: Attention backend class
3869-
return_all: Return all compatible sizes if True, max size if False
3870-
3871-
Returns:
3872-
Compatible block size(s) based on return_all parameter
3873-
3874-
Raises:
3875-
ValueError: If no compatible block size found
3876-
"""
3877-
supported_block_size = backend_cls.get_supported_kernel_block_size()
3878-
compatible_sizes = []
3879-
3880-
for block_size in supported_block_size:
3881-
if isinstance(block_size, int):
3882-
if kv_manager_block_size % block_size == 0:
3883-
compatible_sizes.append(block_size)
3884-
elif isinstance(block_size, MultipleOf) and kv_manager_block_size % block_size.base == 0:
3885-
compatible_sizes.append(kv_manager_block_size)
3886-
3887-
if not compatible_sizes:
3888-
raise ValueError(f"No compatible block size for {kv_manager_block_size}")
3889-
3890-
return compatible_sizes if return_all else [max(compatible_sizes)]
3891-
3892-
def _select_common_block_size(self, kv_manager_block_size: int, attn_groups: list[AttentionGroup]) -> int:
3859+
@staticmethod
3860+
def select_common_block_size(kv_manager_block_size: int, attn_groups: list[AttentionGroup]) -> int:
38933861
"""
3894-
Select common block size for all backends.
3862+
Select a block size that is supported by all backends and is a factor of
3863+
kv_manager_block_size.
3864+
If kv_manager_block_size is supported by all backends, return it directly.
3865+
Otherwise, return the max supported size.
38953866
38963867
Args:
38973868
kv_manager_block_size: Block size of KV cache
38983869
attn_groups: List of attention groups
38993870
39003871
Returns:
3901-
Block size supported by all backends,
3902-
prioritizing cache_config.block_size
3872+
The selected block size
39033873
39043874
Raises:
3905-
ValueError: If no common block size found
3875+
ValueError: If no valid block size found
39063876
"""
3907-
all_backend_supports = []
39083877

3909-
for attn_group in attn_groups:
3910-
compatible_sizes = self._find_compatible_block_sizes(
3911-
kv_manager_block_size, attn_group.backend, return_all=True
3912-
)
3913-
supported_sizes = sorted(list(set(compatible_sizes)), reverse=True)
3914-
all_backend_supports.append(set(supported_sizes))
3915-
3916-
common_supported_sizes = set.intersection(*all_backend_supports)
3917-
3918-
if not common_supported_sizes:
3919-
error_msg = f"No common block size for {kv_manager_block_size}. "
3920-
for i, attn_group in enumerate(attn_groups):
3921-
supported = all_backend_supports[i]
3922-
error_msg += f"Backend {attn_group.backend} supports: {sorted(supported)}. "
3923-
raise ValueError(error_msg)
3924-
3925-
if self.cache_config.block_size in common_supported_sizes:
3926-
return self.cache_config.block_size
3878+
def block_size_is_supported(backends: list[type[AttentionBackend]], block_size: int) -> bool:
3879+
"""
3880+
Check if the block size is supported by all backends.
3881+
"""
3882+
for backend in backends:
3883+
is_supported = False
3884+
for supported_size in backend.get_supported_kernel_block_size():
3885+
if isinstance(supported_size, int):
3886+
if block_size == supported_size:
3887+
is_supported = True
3888+
elif isinstance(supported_size, MultipleOf):
3889+
if block_size % supported_size.base == 0:
3890+
is_supported = True
3891+
else:
3892+
raise ValueError(f"Unknown supported size: {supported_size}")
3893+
if not is_supported:
3894+
return False
3895+
return True
3896+
3897+
backends = [group.backend for group in attn_groups]
3898+
3899+
# Case 1: if the block_size of kv cache manager is supported by all backends,
3900+
# return it directly
3901+
if block_size_is_supported(backends, kv_manager_block_size):
3902+
return kv_manager_block_size
3903+
3904+
# Case 2: otherwise, the block_size must be an `int`-format supported size of
3905+
# at least one backend. Iterate over all `int`-format supported sizes in
3906+
# descending order and return the first one that is supported by all backends.
3907+
# Simple proof:
3908+
# If the supported size b is in MultipleOf(x_i) format for all attention
3909+
# backends i, and b a factor of kv_manager_block_size, then
3910+
# kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
3911+
# return kv_manager_block_size in case 1.
3912+
all_int_supported_sizes = set(
3913+
supported_size
3914+
for backend in backends
3915+
for supported_size in backend.get_supported_kernel_block_size()
3916+
if isinstance(supported_size, int)
3917+
)
39273918

3928-
return max(common_supported_sizes)
3919+
for supported_size in sorted(all_int_supported_sizes, reverse=True):
3920+
if kv_manager_block_size % supported_size != 0:
3921+
continue
3922+
if block_size_is_supported(backends, supported_size):
3923+
return supported_size
3924+
raise ValueError(f"No common block size for {kv_manager_block_size}. ")
39293925

3930-
def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
3926+
def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]) -> None:
39313927
"""
39323928
Re-initialize the input batch if the block sizes are different from
39333929
`[self.cache_config.block_size]`. This usually happens when there
39343930
are multiple KV cache groups.
39353931
39363932
Args:
39373933
kv_cache_config: The KV cache configuration.
3934+
kernel_block_sizes: The kernel block sizes for each KV cache group.
39383935
"""
39393936
block_sizes = [
39403937
kv_cache_group.kv_cache_spec.block_size
39413938
for kv_cache_group in kv_cache_config.kv_cache_groups
39423939
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
39433940
]
39443941

3945-
# Generate kernel_block_sizes that matches each block_size
3946-
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
3947-
39483942
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [self.cache_config.block_size]:
39493943
assert self.cache_config.cpu_offload_gb == 0, (
39503944
"Cannot re-initialize the input batch when CPU weight "
@@ -4035,7 +4029,7 @@ def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[in
40354029
# all backends in the group.
40364030
attn_groups = self.attn_groups[kv_cache_group_id]
40374031
kv_manager_block_size = kv_cache_group.kv_cache_spec.block_size
4038-
selected_kernel_size = self._select_common_block_size(kv_manager_block_size, attn_groups)
4032+
selected_kernel_size = self.select_common_block_size(kv_manager_block_size, attn_groups)
40394033
kernel_block_sizes.append(selected_kernel_size)
40404034
elif isinstance(kv_cache_spec, MambaSpec):
40414035
# This is likely Mamba or other non-attention cache,
@@ -4049,6 +4043,7 @@ def _reshape_kv_cache_tensors(
40494043
self,
40504044
kv_cache_config: KVCacheConfig,
40514045
kv_cache_raw_tensors: dict[str, torch.Tensor],
4046+
kernel_block_sizes: list[int],
40524047
) -> dict[str, torch.Tensor]:
40534048
"""
40544049
Reshape the KV cache tensors to the desired shape and dtype.
@@ -4057,6 +4052,7 @@ def _reshape_kv_cache_tensors(
40574052
kv_cache_config: The KV cache config
40584053
kv_cache_raw_tensors: The KV cache buffer of each layer, with
40594054
correct size but uninitialized shape.
4055+
kernel_block_sizes: The kernel block sizes for each KV cache group.
40604056
Returns:
40614057
Dict[str, torch.Tensor]: A map between layer names to their
40624058
corresponding memory buffer for KV cache.
@@ -4066,6 +4062,10 @@ def _reshape_kv_cache_tensors(
40664062
for group in self._kv_cache_spec_attn_group_iterator():
40674063
kv_cache_spec = group.kv_cache_spec
40684064
attn_backend = group.backend
4065+
if group.kv_cache_group_id == len(kernel_block_sizes):
4066+
# There may be a last group for layers without kv cache.
4067+
continue
4068+
kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
40694069
for layer_name in group.layer_names:
40704070
if layer_name in self.runner_only_attn_layers:
40714071
continue
@@ -4074,24 +4074,19 @@ def _reshape_kv_cache_tensors(
40744074
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
40754075
if isinstance(kv_cache_spec, AttentionSpec):
40764076
has_attn = True
4077-
kv_manager_block_size = kv_cache_spec.block_size
4078-
kernel_size_list = self._find_compatible_block_sizes(
4079-
kv_manager_block_size, attn_backend, return_all=False
4080-
)
4081-
kernel_size = kernel_size_list[0]
4082-
num_blocks_per_kv_block = kv_manager_block_size // kernel_size
4077+
num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size
40834078
kernel_num_blocks = num_blocks * num_blocks_per_kv_block
40844079

40854080
kv_cache_shape = attn_backend.get_kv_cache_shape(
40864081
kernel_num_blocks,
4087-
kernel_size,
4082+
kernel_block_size,
40884083
kv_cache_spec.num_kv_heads,
40894084
kv_cache_spec.head_size,
40904085
cache_dtype_str=self.cache_config.cache_dtype,
40914086
)
40924087
dtype = kv_cache_spec.dtype
40934088
try:
4094-
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() # noqa: E501
4089+
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
40954090
assert len(kv_cache_stride_order) == len(kv_cache_shape)
40964091
except (AttributeError, NotImplementedError):
40974092
kv_cache_stride_order = tuple(range(len(kv_cache_shape)))
@@ -4161,20 +4156,23 @@ def _update_hybrid_attention_mamba_layout(self, kv_caches: dict[str, torch.Tenso
41614156
stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
41624157
)
41634158

4164-
def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
4159+
def initialize_kv_cache_tensors(
4160+
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
4161+
) -> dict[str, torch.Tensor]:
41654162
"""
41664163
Initialize the memory buffer for KV cache.
41674164
41684165
Args:
41694166
kv_cache_config: The KV cache config
4167+
kernel_block_sizes: The kernel block sizes for each KV cache group.
41704168
Returns:
41714169
Dict[str, torch.Tensor]: A map between layer names to their
41724170
corresponding memory buffer for KV cache.
41734171
"""
41744172
# Initialize the memory buffer for KV cache
41754173
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
41764174
# Change the memory buffer to the desired shape
4177-
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors)
4175+
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes)
41784176

41794177
if self.is_elastic:
41804178
kv_caches = self._allocate_kv_cache_from_kvcached(kv_cache_config)
@@ -4231,9 +4229,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
42314229
self.may_add_encoder_only_layers_to_kv_cache_config()
42324230
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
42334231
self.initialize_attn_backend(kv_cache_config)
4232+
# The kernel block size for all KV cache groups. For example, if
4233+
# kv_cache_manager uses block_size 256 for a given group, but the attention
4234+
# backends for that group only supports block_size 64, we will return
4235+
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
4236+
# tokens each.
4237+
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
42344238
# Reinitialize need to after initialize_attn_backend
4235-
self.may_reinitialize_input_batch(kv_cache_config)
4236-
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
4239+
self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
4240+
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config, kernel_block_sizes)
42374241

42384242
if self.speculative_config and self.speculative_config.use_eagle():
42394243
assert isinstance(self.drafter, EagleProposer)

aphrodite/v1/worker/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class AttentionGroup:
138138
metadata_builders: list[AttentionMetadataBuilder]
139139
layer_names: list[str]
140140
kv_cache_spec: KVCacheSpec
141+
kv_cache_group_id: int
141142

142143
@staticmethod
143144
def create_with_metadata_builders(
@@ -146,13 +147,14 @@ def create_with_metadata_builders(
146147
kv_cache_spec: KVCacheSpec,
147148
aphrodite_config: AphroditeConfig,
148149
device: torch.device,
150+
kv_cache_group_id: int,
149151
num_metadata_builders: int = 1,
150152
) -> "AttentionGroup":
151153
metadata_builders = [
152154
backend.get_builder_cls()(kv_cache_spec, layer_names, aphrodite_config, device)
153155
for _ in range(num_metadata_builders)
154156
]
155-
return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec)
157+
return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id)
156158

157159
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
158160
assert len(self.metadata_builders) > ubatch_id

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44

55
from aphrodite.attention import Attention
6+
from aphrodite.attention.backends.abstract import MultipleOf
67
from aphrodite.common.sampling_params import SamplingParams
78
from aphrodite.config import (
89
AphroditeConfig,
@@ -23,6 +24,7 @@
2324
from aphrodite.v1.sample.metadata import SamplingMetadata
2425
from aphrodite.v1.worker.gpu_input_batch import InputBatch
2526
from aphrodite.v1.worker.gpu_model_runner import GPUModelRunner
27+
from aphrodite.v1.worker.utils import AttentionGroup
2628

2729
BLOCK_SIZE = 16
2830
NUM_BLOCKS = 10
@@ -160,6 +162,57 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
160162
return (block_table.block_table.np[req_index, :num_blocks] == req_state.block_ids[0]).all()
161163

162164

165+
def _make_mock_backend_for_kernel_block_size(
166+
supported_sizes: list[int | MultipleOf],
167+
):
168+
class _MockBackend:
169+
@staticmethod
170+
def get_supported_kernel_block_size():
171+
return supported_sizes
172+
173+
return _MockBackend()
174+
175+
176+
def _make_kv_cache_spec() -> FullAttentionSpec:
177+
return FullAttentionSpec(block_size=1, num_kv_heads=1, head_size=1, dtype="float16")
178+
179+
180+
def test_select_common_block_size_prefers_manager_block_size():
181+
backend_a = _make_mock_backend_for_kernel_block_size([MultipleOf(32)])
182+
backend_b = _make_mock_backend_for_kernel_block_size([64, MultipleOf(16)])
183+
attn_groups = [
184+
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
185+
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
186+
]
187+
188+
selected_size = GPUModelRunner.select_common_block_size(128, attn_groups)
189+
assert selected_size == 128
190+
191+
192+
def test_select_common_block_size_uses_largest_shared_int():
193+
backend_a = _make_mock_backend_for_kernel_block_size([128, 64])
194+
backend_b = _make_mock_backend_for_kernel_block_size([64, 32])
195+
attn_groups = [
196+
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
197+
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
198+
]
199+
200+
selected_size = GPUModelRunner.select_common_block_size(256, attn_groups)
201+
assert selected_size == 64
202+
203+
204+
def test_select_common_block_size_no_valid_option():
205+
backend_a = _make_mock_backend_for_kernel_block_size([64])
206+
backend_b = _make_mock_backend_for_kernel_block_size([MultipleOf(16)])
207+
attn_groups = [
208+
AttentionGroup(backend_a, [], [], _make_kv_cache_spec(), 0),
209+
AttentionGroup(backend_b, [], [], _make_kv_cache_spec(), 0),
210+
]
211+
212+
with pytest.raises(ValueError):
213+
GPUModelRunner.select_common_block_size(48, attn_groups)
214+
215+
163216
def test_update_states_new_request(model_runner, dist_init):
164217
req_id = "req_0"
165218

0 commit comments

Comments
 (0)