@@ -3712,6 +3712,7 @@ def get_attn_backends_for_group(
37123712
37133713 def create_attn_groups (
37143714 attn_backends_map : dict [AttentionGroupKey , list [str ]],
3715+ kv_cache_group_id : int ,
37153716 ) -> list [AttentionGroup ]:
37163717 attn_groups : list [AttentionGroup ] = []
37173718 for (attn_backend , kv_cache_spec ), layer_names in attn_backends_map .items ():
@@ -3721,6 +3722,7 @@ def create_attn_groups(
37213722 kv_cache_spec ,
37223723 self .aphrodite_config ,
37233724 self .device ,
3725+ kv_cache_group_id ,
37243726 num_metadata_builders = 1 if not self .parallel_config .enable_dbo else 2 ,
37253727 )
37263728
@@ -3737,8 +3739,8 @@ def create_attn_groups(
37373739 # Resolve cudagraph_mode before actually initializing metadata_builders
37383740 self ._check_and_update_cudagraph_mode (attention_backend_set )
37393741
3740- for attn_backends_map in attention_backend_maps :
3741- self .attn_groups .append (create_attn_groups (attn_backends_map ))
3742+ for i , attn_backend_map in enumerate ( attention_backend_maps ) :
3743+ self .attn_groups .append (create_attn_groups (attn_backend_map , i ))
37423744
37433745 # Calculate reorder batch threshold (if needed)
37443746 self .calculate_reorder_batch_threshold ()
@@ -3854,97 +3856,89 @@ def calculate_reorder_batch_threshold(self) -> None:
38543856 else :
38553857 self .reorder_batch_threshold = None
38563858
3857- def _find_compatible_block_sizes (
3858- self ,
3859- kv_manager_block_size : int ,
3860- backend_cls : type [AttentionBackend ],
3861- return_all : bool = False ,
3862- ) -> list [int ]:
3863- """
3864- Find compatible block sizes for a backend.
3865-
3866- Args:
3867- kv_manager_block_size: Physical block size of KV cache
3868- backend_cls: Attention backend class
3869- return_all: Return all compatible sizes if True, max size if False
3870-
3871- Returns:
3872- Compatible block size(s) based on return_all parameter
3873-
3874- Raises:
3875- ValueError: If no compatible block size found
3876- """
3877- supported_block_size = backend_cls .get_supported_kernel_block_size ()
3878- compatible_sizes = []
3879-
3880- for block_size in supported_block_size :
3881- if isinstance (block_size , int ):
3882- if kv_manager_block_size % block_size == 0 :
3883- compatible_sizes .append (block_size )
3884- elif isinstance (block_size , MultipleOf ) and kv_manager_block_size % block_size .base == 0 :
3885- compatible_sizes .append (kv_manager_block_size )
3886-
3887- if not compatible_sizes :
3888- raise ValueError (f"No compatible block size for { kv_manager_block_size } " )
3889-
3890- return compatible_sizes if return_all else [max (compatible_sizes )]
3891-
3892- def _select_common_block_size (self , kv_manager_block_size : int , attn_groups : list [AttentionGroup ]) -> int :
3859+ @staticmethod
3860+ def select_common_block_size (kv_manager_block_size : int , attn_groups : list [AttentionGroup ]) -> int :
38933861 """
3894- Select common block size for all backends.
3862+ Select a block size that is supported by all backends and is a factor of
3863+ kv_manager_block_size.
3864+ If kv_manager_block_size is supported by all backends, return it directly.
3865+ Otherwise, return the max supported size.
38953866
38963867 Args:
38973868 kv_manager_block_size: Block size of KV cache
38983869 attn_groups: List of attention groups
38993870
39003871 Returns:
3901- Block size supported by all backends,
3902- prioritizing cache_config.block_size
3872+ The selected block size
39033873
39043874 Raises:
3905- ValueError: If no common block size found
3875+ ValueError: If no valid block size found
39063876 """
3907- all_backend_supports = []
39083877
3909- for attn_group in attn_groups :
3910- compatible_sizes = self ._find_compatible_block_sizes (
3911- kv_manager_block_size , attn_group .backend , return_all = True
3912- )
3913- supported_sizes = sorted (list (set (compatible_sizes )), reverse = True )
3914- all_backend_supports .append (set (supported_sizes ))
3915-
3916- common_supported_sizes = set .intersection (* all_backend_supports )
3917-
3918- if not common_supported_sizes :
3919- error_msg = f"No common block size for { kv_manager_block_size } . "
3920- for i , attn_group in enumerate (attn_groups ):
3921- supported = all_backend_supports [i ]
3922- error_msg += f"Backend { attn_group .backend } supports: { sorted (supported )} . "
3923- raise ValueError (error_msg )
3924-
3925- if self .cache_config .block_size in common_supported_sizes :
3926- return self .cache_config .block_size
3878+ def block_size_is_supported (backends : list [type [AttentionBackend ]], block_size : int ) -> bool :
3879+ """
3880+ Check if the block size is supported by all backends.
3881+ """
3882+ for backend in backends :
3883+ is_supported = False
3884+ for supported_size in backend .get_supported_kernel_block_size ():
3885+ if isinstance (supported_size , int ):
3886+ if block_size == supported_size :
3887+ is_supported = True
3888+ elif isinstance (supported_size , MultipleOf ):
3889+ if block_size % supported_size .base == 0 :
3890+ is_supported = True
3891+ else :
3892+ raise ValueError (f"Unknown supported size: { supported_size } " )
3893+ if not is_supported :
3894+ return False
3895+ return True
3896+
3897+ backends = [group .backend for group in attn_groups ]
3898+
3899+ # Case 1: if the block_size of kv cache manager is supported by all backends,
3900+ # return it directly
3901+ if block_size_is_supported (backends , kv_manager_block_size ):
3902+ return kv_manager_block_size
3903+
3904+ # Case 2: otherwise, the block_size must be an `int`-format supported size of
3905+ # at least one backend. Iterate over all `int`-format supported sizes in
3906+ # descending order and return the first one that is supported by all backends.
3907+ # Simple proof:
3908+ # If the supported size b is in MultipleOf(x_i) format for all attention
3909+ # backends i, and b a factor of kv_manager_block_size, then
3910+ # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
3911+ # return kv_manager_block_size in case 1.
3912+ all_int_supported_sizes = set (
3913+ supported_size
3914+ for backend in backends
3915+ for supported_size in backend .get_supported_kernel_block_size ()
3916+ if isinstance (supported_size , int )
3917+ )
39273918
3928- return max (common_supported_sizes )
3919+ for supported_size in sorted (all_int_supported_sizes , reverse = True ):
3920+ if kv_manager_block_size % supported_size != 0 :
3921+ continue
3922+ if block_size_is_supported (backends , supported_size ):
3923+ return supported_size
3924+ raise ValueError (f"No common block size for { kv_manager_block_size } . " )
39293925
3930- def may_reinitialize_input_batch (self , kv_cache_config : KVCacheConfig ) -> None :
3926+ def may_reinitialize_input_batch (self , kv_cache_config : KVCacheConfig , kernel_block_sizes : list [ int ] ) -> None :
39313927 """
39323928 Re-initialize the input batch if the block sizes are different from
39333929 `[self.cache_config.block_size]`. This usually happens when there
39343930 are multiple KV cache groups.
39353931
39363932 Args:
39373933 kv_cache_config: The KV cache configuration.
3934+ kernel_block_sizes: The kernel block sizes for each KV cache group.
39383935 """
39393936 block_sizes = [
39403937 kv_cache_group .kv_cache_spec .block_size
39413938 for kv_cache_group in kv_cache_config .kv_cache_groups
39423939 if not isinstance (kv_cache_group .kv_cache_spec , EncoderOnlyAttentionSpec )
39433940 ]
39443941
3945- # Generate kernel_block_sizes that matches each block_size
3946- kernel_block_sizes = self ._prepare_kernel_block_sizes (kv_cache_config )
3947-
39483942 if block_sizes != [self .cache_config .block_size ] or kernel_block_sizes != [self .cache_config .block_size ]:
39493943 assert self .cache_config .cpu_offload_gb == 0 , (
39503944 "Cannot re-initialize the input batch when CPU weight "
@@ -4035,7 +4029,7 @@ def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[in
40354029 # all backends in the group.
40364030 attn_groups = self .attn_groups [kv_cache_group_id ]
40374031 kv_manager_block_size = kv_cache_group .kv_cache_spec .block_size
4038- selected_kernel_size = self ._select_common_block_size (kv_manager_block_size , attn_groups )
4032+ selected_kernel_size = self .select_common_block_size (kv_manager_block_size , attn_groups )
40394033 kernel_block_sizes .append (selected_kernel_size )
40404034 elif isinstance (kv_cache_spec , MambaSpec ):
40414035 # This is likely Mamba or other non-attention cache,
@@ -4049,6 +4043,7 @@ def _reshape_kv_cache_tensors(
40494043 self ,
40504044 kv_cache_config : KVCacheConfig ,
40514045 kv_cache_raw_tensors : dict [str , torch .Tensor ],
4046+ kernel_block_sizes : list [int ],
40524047 ) -> dict [str , torch .Tensor ]:
40534048 """
40544049 Reshape the KV cache tensors to the desired shape and dtype.
@@ -4057,6 +4052,7 @@ def _reshape_kv_cache_tensors(
40574052 kv_cache_config: The KV cache config
40584053 kv_cache_raw_tensors: The KV cache buffer of each layer, with
40594054 correct size but uninitialized shape.
4055+ kernel_block_sizes: The kernel block sizes for each KV cache group.
40604056 Returns:
40614057 Dict[str, torch.Tensor]: A map between layer names to their
40624058 corresponding memory buffer for KV cache.
@@ -4066,6 +4062,10 @@ def _reshape_kv_cache_tensors(
40664062 for group in self ._kv_cache_spec_attn_group_iterator ():
40674063 kv_cache_spec = group .kv_cache_spec
40684064 attn_backend = group .backend
4065+ if group .kv_cache_group_id == len (kernel_block_sizes ):
4066+ # There may be a last group for layers without kv cache.
4067+ continue
4068+ kernel_block_size = kernel_block_sizes [group .kv_cache_group_id ]
40694069 for layer_name in group .layer_names :
40704070 if layer_name in self .runner_only_attn_layers :
40714071 continue
@@ -4074,24 +4074,19 @@ def _reshape_kv_cache_tensors(
40744074 num_blocks = raw_tensor .numel () // kv_cache_spec .page_size_bytes
40754075 if isinstance (kv_cache_spec , AttentionSpec ):
40764076 has_attn = True
4077- kv_manager_block_size = kv_cache_spec .block_size
4078- kernel_size_list = self ._find_compatible_block_sizes (
4079- kv_manager_block_size , attn_backend , return_all = False
4080- )
4081- kernel_size = kernel_size_list [0 ]
4082- num_blocks_per_kv_block = kv_manager_block_size // kernel_size
4077+ num_blocks_per_kv_block = kv_cache_spec .block_size // kernel_block_size
40834078 kernel_num_blocks = num_blocks * num_blocks_per_kv_block
40844079
40854080 kv_cache_shape = attn_backend .get_kv_cache_shape (
40864081 kernel_num_blocks ,
4087- kernel_size ,
4082+ kernel_block_size ,
40884083 kv_cache_spec .num_kv_heads ,
40894084 kv_cache_spec .head_size ,
40904085 cache_dtype_str = self .cache_config .cache_dtype ,
40914086 )
40924087 dtype = kv_cache_spec .dtype
40934088 try :
4094- kv_cache_stride_order = attn_backend .get_kv_cache_stride_order () # noqa: E501
4089+ kv_cache_stride_order = attn_backend .get_kv_cache_stride_order ()
40954090 assert len (kv_cache_stride_order ) == len (kv_cache_shape )
40964091 except (AttributeError , NotImplementedError ):
40974092 kv_cache_stride_order = tuple (range (len (kv_cache_shape )))
@@ -4161,20 +4156,23 @@ def _update_hybrid_attention_mamba_layout(self, kv_caches: dict[str, torch.Tenso
41614156 stride = (hidden_size , 2 * hidden_size , * kv_cache .stride ()[2 :]),
41624157 )
41634158
4164- def initialize_kv_cache_tensors (self , kv_cache_config : KVCacheConfig ) -> dict [str , torch .Tensor ]:
4159+ def initialize_kv_cache_tensors (
4160+ self , kv_cache_config : KVCacheConfig , kernel_block_sizes : list [int ]
4161+ ) -> dict [str , torch .Tensor ]:
41654162 """
41664163 Initialize the memory buffer for KV cache.
41674164
41684165 Args:
41694166 kv_cache_config: The KV cache config
4167+ kernel_block_sizes: The kernel block sizes for each KV cache group.
41704168 Returns:
41714169 Dict[str, torch.Tensor]: A map between layer names to their
41724170 corresponding memory buffer for KV cache.
41734171 """
41744172 # Initialize the memory buffer for KV cache
41754173 kv_cache_raw_tensors = self ._allocate_kv_cache_tensors (kv_cache_config )
41764174 # Change the memory buffer to the desired shape
4177- kv_caches = self ._reshape_kv_cache_tensors (kv_cache_config , kv_cache_raw_tensors )
4175+ kv_caches = self ._reshape_kv_cache_tensors (kv_cache_config , kv_cache_raw_tensors , kernel_block_sizes )
41784176
41794177 if self .is_elastic :
41804178 kv_caches = self ._allocate_kv_cache_from_kvcached (kv_cache_config )
@@ -4231,9 +4229,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
42314229 self .may_add_encoder_only_layers_to_kv_cache_config ()
42324230 self .maybe_add_kv_sharing_layers_to_kv_cache_groups (kv_cache_config )
42334231 self .initialize_attn_backend (kv_cache_config )
4232+ # The kernel block size for all KV cache groups. For example, if
4233+ # kv_cache_manager uses block_size 256 for a given group, but the attention
4234+ # backends for that group only support block_size 64, we will return
4235+ # kernel_block_size 64 and split the 256-token-block into 4 blocks with 64
4236+ # tokens each.
4237+ kernel_block_sizes = self ._prepare_kernel_block_sizes (kv_cache_config )
42344238 # Re-initialization must happen after initialize_attn_backend
4235- self .may_reinitialize_input_batch (kv_cache_config )
4236- kv_caches = self .initialize_kv_cache_tensors (kv_cache_config )
4239+ self .may_reinitialize_input_batch (kv_cache_config , kernel_block_sizes )
4240+ kv_caches = self .initialize_kv_cache_tensors (kv_cache_config , kernel_block_sizes )
42374241
42384242 if self .speculative_config and self .speculative_config .use_eagle ():
42394243 assert isinstance (self .drafter , EagleProposer )
0 commit comments