 Compatible with A10, H100, AMD MI300X.
 """
 
-from typing import Final
-
 import torch
 import triton
 import triton.language as tl
 from triton.language.extra import libdevice  # type: ignore[attr-defined]
 
 from conch.platforms import current_platform
 
-# The maximum number of stage 1 kernels to launch to split processing of the sequence
-MAX_NUM_SPLITS: Final = 64
-
 
 @triton.jit  # type: ignore[misc]
 def _paged_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
@@ -27,7 +22,7 @@ def _paged_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
     query_ptr: tl.tensor,  # (batch_size, num_query_heads, head_size)
     key_cache_ptr: tl.tensor,  # (num_cache_blocks, cache_block_size, num_kv_heads, head_size)
     value_cache_ptr: tl.tensor,  # (num_cache_blocks, cache_block_size, num_kv_heads, head_size)
-    block_tables_ptr: tl.tensor,  # (batch_size, max_num_blocks_per_sequence)
+    block_table_ptr: tl.tensor,  # (batch_size, max_num_blocks_per_sequence)
     seq_lens_ptr: tl.tensor,  # (batch_size, )
     k_scale_ptr: tl.tensor,  # (1,)
     v_scale_ptr: tl.tensor,  # (1,)
@@ -50,7 +45,7 @@ def _paged_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
     kv_cache_block_stride: int,  # key_cache.stride(1), same for key and value
     kv_head_stride: int,  # key_cache.stride(2), same for key and value
     kv_head_element_stride: int,  # key_cache.stride(3), same for key and value
-    block_tables_batch_stride: int,  # block_tables.stride(0)
+    block_table_batch_stride: int,  # block_table.stride(0)
     # Constexprs
     cxpr_cache_block_size: tl.constexpr,
     cxpr_head_size_padded: tl.constexpr,
@@ -67,7 +62,7 @@ def _paged_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
         query_ptr: Pointer to tensor storing the query, shape: (batch_size, num_query_heads, head_size).
         key_cache_ptr: Tensor with cached K values, shape: (num_blocks, cache_block_size, num_kv_heads, head_size).
         value_cache_ptr: Tensor with cached V values, shape: (num_blocks, cache_block_size, num_kv_heads, head_size).
-        block_tables_ptr: Pointer to tensor storing the mapping from batch to cache blocks, shape: (batch_size, max_num_blocks_per_sequence).
+        block_table_ptr: Pointer to tensor storing the mapping from batch to cache blocks, shape: (batch_size, max_num_blocks_per_sequence).
         seq_lens_ptr: Pointer to tensor holding the current sequence length for each sequence in the batch, shape: (batch_size, ).
         k_scale_ptr: Pointer to tensor holding fp8 scaling factor for k.
         v_scale_ptr: Pointer to tensor holding fp8 scaling factor for v.
@@ -87,7 +82,7 @@ def _paged_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
         kv_cache_block_stride: Stride of the k/v tensors in the 1st dimension.
         kv_head_stride: Stride of the k/v tensors in the 2nd dimension.
         kv_head_element_stride: Stride of the k/v tensors in the 3rd dimension.
-        block_tables_batch_stride: Stride of the block table tensor in the 0th dimension.
+        block_table_batch_stride: Stride of the block table tensor in the 0th dimension.
         cxpr_cache_block_size: The size of the cache blocks (must be power of two!).
         cxpr_head_size_padded: The head size of the attention layer padded to the next power of two.
         cxpr_query_group_size_padded: The query group size padded to the next power of two.
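For reference, a minimal sketch (sizes invented, not from this diff) of where the launcher-side stride arguments documented above come from for a contiguous key cache with the (num_blocks, cache_block_size, num_kv_heads, head_size) layout:

```python
import torch

# Invented sizes, purely illustrative.
num_cache_blocks, cache_block_size, num_kv_heads, head_size = 128, 32, 8, 128
key_cache = torch.empty(num_cache_blocks, cache_block_size, num_kv_heads, head_size)

# For a contiguous tensor each stride is the product of the trailing dimensions.
kv_cache_block_stride = key_cache.stride(1)   # num_kv_heads * head_size == 1024
kv_head_stride = key_cache.stride(2)          # head_size == 128
kv_head_element_stride = key_cache.stride(3)  # 1
print(kv_cache_block_stride, kv_head_stride, kv_head_element_stride)
```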
@@ -147,9 +142,9 @@ def _paged_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
     # Offset for the current kv_head in the key_cache and value_cache
     kv_head_index_offset = kv_head_index * kv_head_stride
 
-    # Pointer arithmetic to get to the entry in the block_tables for the current batch_index
-    current_block_table_offset = batch_index * block_tables_batch_stride
-    current_block_table_ptr = block_tables_ptr + current_block_table_offset
+    # Pointer arithmetic to get to the entry in the block_table for the current batch_index
+    current_block_table_offset = batch_index * block_table_batch_stride
+    current_block_table_ptr = block_table_ptr + current_block_table_offset
 
     # Scratchpad for output from this group of cache blocks
     output = tl.zeros([cxpr_query_group_size_padded, cxpr_head_size_padded], dtype=dtype)
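As a rough illustration of what the pointer arithmetic in this hunk selects, here is the same lookup done with plain tensor indexing; the table contents below are invented:

```python
import torch

# Invented block table: each row maps a sequence's logical cache blocks to
# physical block indices in the paged KV cache.
block_table = torch.tensor([
    [3, 7, 1, 0],  # sequence 0
    [5, 2, 9, 4],  # sequence 1
])
block_table_batch_stride = block_table.stride(0)  # max_num_blocks_per_sequence for a contiguous table

batch_index = 1
# Equivalent of `block_table_ptr + batch_index * block_table_batch_stride`:
row_start = batch_index * block_table_batch_stride
physical_blocks = block_table.flatten()[row_start : row_start + block_table.shape[1]]
print(physical_blocks)  # tensor([5, 2, 9, 4])
```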
@@ -187,7 +182,7 @@ def _paged_attention_compute_splits_kernel( # noqa: PLR0913, PLR0915
         # Load the key block as (cxpr_head_size_padded, cache_block_size)
         # Note: we're loading it transposed here
         key_block_offsets = (
-            head_offsets[:, None] + kv_head_index_offset + cache_block_offsets[None, :] * kv_cache_block_stride
+            cache_block_offsets[None, :] * kv_cache_block_stride + kv_head_index_offset + head_offsets[:, None]
         )
 
         key_block_mask = head_mask[:, None] & cache_block_mask[None, :]
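Since this hunk only reorders the terms of a commutative broadcasted sum, here is a quick sanity sketch (sizes invented) showing that both orderings produce the same (head_size, cache_block_size) offset grid:

```python
import torch

# Invented sizes, small enough to eyeball.
cache_block_size, head_size = 4, 8
kv_cache_block_stride, kv_head_index_offset = 64, 16

head_offsets = torch.arange(head_size)                # (head_size,)
cache_block_offsets = torch.arange(cache_block_size)  # (cache_block_size,)

old = head_offsets[:, None] + kv_head_index_offset + cache_block_offsets[None, :] * kv_cache_block_stride
new = cache_block_offsets[None, :] * kv_cache_block_stride + kv_head_index_offset + head_offsets[:, None]
assert torch.equal(old, new)
print(new.shape)  # torch.Size([8, 4])
```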
@@ -433,7 +428,7 @@ def paged_attention_launcher( # noqa: PLR0913
     value_cache: torch.Tensor,
     output_scratchpad: torch.Tensor,
     lse_scratchpad: torch.Tensor,
-    block_tables: torch.Tensor,
+    block_table: torch.Tensor,
     seq_lens: torch.Tensor,
     scale: float | None = None,
     softcap: float = 0.0,
@@ -450,7 +445,7 @@ def paged_attention_launcher( # noqa: PLR0913
         value_cache: Tensor with cached V values, shape: (num_blocks, cache_block_size, num_kv_heads, head_size).
         output_scratchpad: Tensor used as scratchpad to share cache block outputs between two stages, shape: (batch_size, max_num_blocks_per_sequence, num_query_heads, head_size).
         lse_scratchpad: Tensor used as scratchpad to share cache block log-sum-exp between two stages, shape: (batch_size, max_num_blocks_per_sequence, num_query_heads).
-        block_tables: Tensor storing the mapping from batch to cache blocks, shape: (batch_size, max_num_blocks_per_sequence).
+        block_table: Tensor storing the mapping from batch to cache blocks, shape: (batch_size, max_num_blocks_per_sequence).
         seq_lens: Tensor with the sequence length of each index in the batch, shape: (batch_size, ).
         scale: Scaling factor, 1/sqrt(head_size).
         softcap: Logit softcap to apply (0.0 means no softcap will be applied).
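A small aside on the `scale` and `softcap` arguments documented above. The sketch below shows the usual conventions; how this launcher applies them internally is an assumption here, not taken from the diff:

```python
import math

head_size = 128
scale = None     # launcher default
softcap = 50.0   # 0.0 would mean "no softcap"

# Per the docstring, scale is 1/sqrt(head_size); presumably resolved when None is passed.
effective_scale = 1.0 / math.sqrt(head_size) if scale is None else scale

# Common logit-softcap convention (an assumption, not verified against this kernel):
# squash the attention logit with tanh so it stays within +/- softcap.
def softcapped(logit: float, cap: float) -> float:
    return cap * math.tanh(logit / cap) if cap > 0.0 else logit

print(effective_scale, softcapped(123.0, softcap))
```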
@@ -476,7 +471,8 @@ def paged_attention_launcher( # noqa: PLR0913
     # Perform unchecked size accesses; assume shapes have already been validated
     batch_size, num_query_heads, head_size = out.shape
     num_cache_blocks, cache_block_size, num_kv_heads, _ = key_cache.shape
-    _, max_num_blocks_per_sequence = block_tables.shape
+    _, max_num_blocks_per_sequence = block_table.shape
+    _, max_num_splits, _, _ = output_scratchpad.shape
 
     assert cache_block_size == triton.next_power_of_2(cache_block_size), "Cache block size must be a power of two!"  # noqa: S101
 
@@ -497,10 +493,10 @@ def paged_attention_launcher( # noqa: PLR0913
 
     # What is the maximum number of stage 1 kernels to launch per batch/head?
     # Each kernel processes up to {cache_block_size} tokens at a time (in many cases cache_block_size=32 for vLLM), so we can process
-    # a sequence up to {MAX_NUM_SPLITS * cache_block_size} == 64 * 32 == 2048 tokens before a stage 1 kernel will process multiple cache
+    # a sequence up to {max_num_splits * cache_block_size} == 64 * 32 == 2048 tokens before a stage 1 kernel will process multiple cache
     # blocks. This helps to reduce the overhead of kernel launches / split reduction for long sequences.
     # Note: we may need to tune this value for a given HW platform.
-    num_splits = min(max_num_blocks_per_sequence, MAX_NUM_SPLITS)
+    num_splits = min(max_num_blocks_per_sequence, max_num_splits)
 
     num_cache_blocks_per_split = triton.cdiv(max_num_blocks_per_sequence, num_splits)
 
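A worked example (sizes invented) of the split sizing above, showing how `max_num_splits` now comes from the scratchpad shape rather than the removed `MAX_NUM_SPLITS` constant:

```python
import torch
import triton

# Invented sizes: the longest sequence in the batch spans 100 cache blocks.
max_num_blocks_per_sequence = 100
batch_size, num_query_heads, head_size = 2, 32, 128

# The scratchpad's second dimension now bounds the number of stage 1 splits.
output_scratchpad = torch.empty(batch_size, 64, num_query_heads, head_size)
_, max_num_splits, _, _ = output_scratchpad.shape

num_splits = min(max_num_blocks_per_sequence, max_num_splits)                      # min(100, 64) == 64
num_cache_blocks_per_split = triton.cdiv(max_num_blocks_per_sequence, num_splits)  # ceil(100 / 64) == 2
print(num_splits, num_cache_blocks_per_split)  # 64 2
```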
@@ -527,7 +523,7 @@ def paged_attention_launcher( # noqa: PLR0913
         query,
         key_cache,
         value_cache,
-        block_tables,
+        block_table,
         seq_lens,
         k_scale,
         v_scale,
@@ -550,7 +546,7 @@ def paged_attention_launcher( # noqa: PLR0913
         key_cache.stride(1),
         key_cache.stride(2),
         key_cache.stride(3),
-        block_tables.stride(0),
+        block_table.stride(0),
         # Constexpr sizes
         cxpr_cache_block_size,
         cxpr_head_size_padded,