
Commit f90614e

Misc cleanup/performance improvements
1 parent 5eca02d commit f90614e

10 files changed

Lines changed: 287 additions & 218 deletions

benchmarks/paged_attention_benchmark.py

Lines changed: 14 additions & 16 deletions
@@ -165,18 +165,16 @@ def main(
         },
     )
 
-    query, key_cache_vllm, value_cache_vllm, key_cache_conch, value_cache_conch, block_tables, seq_lens = (
-        create_tensors(
-            head_dim,
-            seq_len,
-            cache_block_size,
-            batch_size,
-            num_query_heads,
-            num_kv_heads,
-            kv_cache_dtype,
-            device,
-            dtype,
-        )
+    query, key_cache_vllm, value_cache_vllm, key_cache_conch, value_cache_conch, block_table, seq_lens = create_tensors(
+        head_dim,
+        seq_len,
+        cache_block_size,
+        batch_size,
+        num_query_heads,
+        num_kv_heads,
+        kv_cache_dtype,
+        device,
+        dtype,
     )
 
     scale: Final = float(1.0 / (head_dim**0.5))
@@ -191,7 +189,7 @@ def main(
         query,
         key_cache_conch,
         value_cache_conch,
-        block_tables,
+        block_table,
         seq_lens,
         output=output_conch,
         scale=scale,
@@ -232,7 +230,7 @@ def main(
         value_cache_vllm,
         num_kv_heads,
         scale,
-        block_tables,
+        block_table,
         seq_lens,
         cache_block_size,
         max_seq_len,
@@ -263,7 +261,7 @@ def main(
         value_cache_vllm,
         num_kv_heads,
         scale,
-        block_tables,
+        block_table,
         seq_lens,
         cache_block_size,
         max_seq_len,
@@ -287,7 +285,7 @@ def main(
         query,
         key_cache_conch,
         value_cache_conch,
-        block_tables,
+        block_table,
         seq_lens,
         output=output_conch,
         scale=scale,
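
For context on the two tensors this rename touches, here is a minimal illustrative sketch (not part of this commit); the shapes follow the docstrings in conch/kernels/attention/paged_attention.py, and the concrete numbers are arbitrary.

import torch

# Illustrative only: each row of block_table lists the KV-cache block indices
# owned by one sequence in the batch; seq_lens holds that sequence's length.
batch_size, max_num_blocks_per_sequence, cache_block_size = 2, 4, 32

block_table = torch.randint(0, 128, (batch_size, max_num_blocks_per_sequence), dtype=torch.int32)
seq_lens = torch.tensor([100, 57], dtype=torch.int32)

# Token t of sequence b lives in cache block block_table[b, t // cache_block_size],
# at slot t % cache_block_size within that block.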

benchmarks/paged_attention_vs_flash_benchmark.py

Lines changed: 6 additions & 6 deletions
@@ -158,11 +158,11 @@ def main(
 
     kv_cache_dtype = "auto"
 
-    query, _, _, key_cache, value_cache, block_tables, seq_lens = create_tensors(
+    query, _, _, key_cache, value_cache, block_table, seq_lens = create_tensors(
         head_dim, seq_len, cache_block_size, batch_size, num_query_heads, num_kv_heads, kv_cache_dtype, device, dtype
     )
 
-    _, max_num_blocks_per_seq = block_tables.shape
+    _, max_num_blocks_per_seq = block_table.shape
 
     scale: Final = float(1.0 / (head_dim**0.5))
 
@@ -181,7 +181,7 @@ def main(
         query_vllm,
         key_cache,
         value_cache,
-        block_table=block_tables,
+        block_table=block_table,
         cache_seqlens=seq_lens,
         softmax_scale=scale,
         causal=True,
@@ -195,7 +195,7 @@ def main(
         query,
         key_cache,
         value_cache,
-        block_tables,
+        block_table,
         seq_lens,
         output=output_conch,
         scale=scale,
@@ -220,7 +220,7 @@ def main(
         query_vllm,
         key_cache,
         value_cache,
-        block_table=block_tables,
+        block_table=block_table,
         cache_seqlens=seq_lens,
         softmax_scale=scale,
         causal=True,
@@ -238,7 +238,7 @@ def main(
         query,
         key_cache,
         value_cache,
-        block_tables,
+        block_table,
         seq_lens,
         output=output_conch,
         scale=scale,

benchmarks/varlen_attention_benchmark.py

Lines changed: 7 additions & 6 deletions
@@ -172,7 +172,7 @@ def main(
     kv_cache_dtype: Final = "auto"
     dtype: Final = torch.float16
 
-    _, _, _, key_cache, value_cache, block_tables, seq_lens = create_tensors(
+    _, _, _, key_cache, value_cache, block_table, seq_lens = create_tensors(
         head_dim,
         seq_len,
         cache_block_size,
@@ -215,7 +215,7 @@ def main(
         query=query,
         key_cache=key_cache,
         value_cache=value_cache,
-        block_tables=block_tables,
+        block_table=block_table,
         seq_lens=seq_lens,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
@@ -233,7 +233,7 @@ def main(
         cu_seqlens_q=cu_seqlens_q,
         max_seqlen_q=max_seqlen_q,
         max_seqlen_k=max_seqlen_k,
-        block_table=block_tables,
+        block_table=block_table,
         seqused_k=seq_lens,
         softmax_scale=scale,
         causal=causal,
@@ -257,7 +257,7 @@ def main(
         cu_seqlens_q=cu_seqlens_q,
         max_seqlen_q=max_seqlen_q,
         max_seqlen_k=max_seqlen_k,
-        block_table=block_tables,
+        block_table=block_table,
         seqused_k=seq_lens,
         softmax_scale=scale,
         causal=causal,
@@ -288,7 +288,7 @@ def main(
         softmax_scale=scale,
         causal=causal,
         window_size=(-1, -1),
-        block_table=block_tables,
+        block_table=block_table,
         softcap=0.0,
         q_descale=None,
         k_descale=None,
@@ -308,8 +308,9 @@ def main(
         query=query,
         key_cache=key_cache,
         value_cache=value_cache,
-        block_tables=block_tables,
+        block_table=block_table,
         seq_lens=seq_lens,
+        output=output_conch,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
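
Besides the rename, the last hunk passes a pre-allocated output tensor to the conch varlen call. A plausible allocation sketch, assuming the varlen output shares the packed query's shape and dtype (an assumption, not shown in this diff):

import torch

# Assumption: output has the packed-query layout (total_tokens, num_query_heads, head_dim).
total_tokens, num_query_heads, head_dim = 4096, 32, 128
query = torch.randn(total_tokens, num_query_heads, head_dim, dtype=torch.float16)
output_conch = torch.empty_like(query)  # passed above as output=output_conch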

conch/kernels/attention/paged_attention.py

Lines changed: 16 additions & 20 deletions
@@ -6,18 +6,13 @@
 Compatible with A10, H100, AMD MI300X.
 """
 
-from typing import Final
-
 import torch
 import triton
 import triton.language as tl
 from triton.language.extra import libdevice  # type: ignore[attr-defined]
 
 from conch.platforms import current_platform
 
-# The maximum number of stage 1 kernels to launch to split processing of the sequence
-MAX_NUM_SPLITS: Final = 64
-
 
 @triton.jit  # type: ignore[misc]
 def _paged_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
@@ -27,7 +22,7 @@ def _paged_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
     query_ptr: tl.tensor,  # (batch_size, num_query_heads, head_size)
     key_cache_ptr: tl.tensor,  # (num_cache_blocks, cache_block_size, num_kv_heads, head_size)
     value_cache_ptr: tl.tensor,  # (num_cache_blocks, cache_block_size, num_kv_heads, head_size)
-    block_tables_ptr: tl.tensor,  # (batch_size, max_num_blocks_per_sequence)
+    block_table_ptr: tl.tensor,  # (batch_size, max_num_blocks_per_sequence)
     seq_lens_ptr: tl.tensor,  # (batch_size, )
     k_scale_ptr: tl.tensor,  # (1,)
     v_scale_ptr: tl.tensor,  # (1,)
@@ -50,7 +45,7 @@ def _paged_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
     kv_cache_block_stride: int,  # key_cache.stride(1), same for key and value
     kv_head_stride: int,  # key_cache.stride(2), same for key and value
     kv_head_element_stride: int,  # key_cache.stride(3), same for key and value
-    block_tables_batch_stride: int,  # block_tables.stride(0)
+    block_table_batch_stride: int,  # block_table.stride(0)
     # Constexprs
     cxpr_cache_block_size: tl.constexpr,
     cxpr_head_size_padded: tl.constexpr,
@@ -67,7 +62,7 @@ def _paged_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
         query_ptr: Pointer to tensor storing the query, shape: (batch_size, num_query_heads, head_size).
         key_cache_ptr: Tensor with cached K values, shape: (num_blocks, cache_block_size, num_kv_heads, head_size).
         value_cache_ptr: Tensor with cached V values, shape: (num_blocks, cache_block_size, num_kv_heads, head_size).
-        block_tables_ptr: Pointer to tensor storing the mapping from batch to cache blocks, shape: (batch_size, max_num_blocks_per_sequence).
+        block_table_ptr: Pointer to tensor storing the mapping from batch to cache blocks, shape: (batch_size, max_num_blocks_per_sequence).
         seq_lens_ptr: Pointer to tensor holding the current sequence length for each sequence in the batch, shape: (batch_size, ).
         k_scale_ptr: Pointer to tensor holding fp8 scaling factor for k.
         v_scale_ptr: Pointer to tensor holding fp8 scaling factor for v.
@@ -87,7 +82,7 @@ def _paged_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
         kv_cache_block_stride: Stride of the k/v tensors in the 1st dimension.
         kv_head_stride: Stride of the k/v tensors in the 2nd dimension.
         kv_head_element_stride: Stride of the k/v tensors in the 3rd dimension.
-        block_tables_batch_stride: Stride of the block table tensor in the 0th dimension.
+        block_table_batch_stride: Stride of the block table tensor in the 0th dimension.
         cxpr_cache_block_size: The size of the cache blocks (must be power of two!).
         cxpr_head_size_padded: The head size of the attention layer padded to the next power of two.
         cxpr_query_group_size_padded: The query group size padded to the next power of two.
@@ -147,9 +142,9 @@ def _paged_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
     # Offset for the current kv_head in the key_cache and value_cache
     kv_head_index_offset = kv_head_index * kv_head_stride
 
-    # Pointer arithmetic to get to the entry in the block_tables for the current batch_index
-    current_block_table_offset = batch_index * block_tables_batch_stride
-    current_block_table_ptr = block_tables_ptr + current_block_table_offset
+    # Pointer arithmetic to get to the entry in the block_table for the current batch_index
+    current_block_table_offset = batch_index * block_table_batch_stride
+    current_block_table_ptr = block_table_ptr + current_block_table_offset
 
     # Scratchpad for output from this group of cache blocks
     output = tl.zeros([cxpr_query_group_size_padded, cxpr_head_size_padded], dtype=dtype)
@@ -187,7 +182,7 @@ def _paged_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
     # Load the key block as (cxpr_head_size_padded, cache_block_size)
     # Note: we're loading it transposed here
     key_block_offsets = (
-        head_offsets[:, None] + kv_head_index_offset + cache_block_offsets[None, :] * kv_cache_block_stride
+        cache_block_offsets[None, :] * kv_cache_block_stride + kv_head_index_offset + head_offsets[:, None]
     )
 
     key_block_mask = head_mask[:, None] & cache_block_mask[None, :]
@@ -433,7 +428,7 @@ def paged_attention_launcher(  # noqa: PLR0913
     value_cache: torch.Tensor,
     output_scratchpad: torch.Tensor,
     lse_scratchpad: torch.Tensor,
-    block_tables: torch.Tensor,
+    block_table: torch.Tensor,
     seq_lens: torch.Tensor,
     scale: float | None = None,
     softcap: float = 0.0,
@@ -450,7 +445,7 @@ def paged_attention_launcher(  # noqa: PLR0913
         value_cache: Tensor with cached V values, shape: (num_blocks, cache_block_size, num_kv_heads, head_size).
         output_scratchpad: Tensor used as scratchpad to share cache block outputs between two stages, shape: (batch_size, max_num_blocks_per_sequence, num_query_heads, head_size)
         lse_scratchpad: Tensor used as scratchpad to share cache block log-sum-exp between two stages, shape: (batch_size, max_num_blocks_per_sequence, num_query_heads)
-        block_tables: Tensor storing the mapping from batch to cache blocks, shape: (batch_size, max_num_blocks_per_sequence).
+        block_table: Tensor storing the mapping from batch to cache blocks, shape: (batch_size, max_num_blocks_per_sequence).
         seq_lens: Tensor with the sequence length of each index in the batch, shape: (batch_size, ).
         scale: Scaling factor, 1/sqrt(head_size).
         softcap: Logit softcap to apply (0.0 means no softcap will be applied).
@@ -476,7 +471,8 @@ def paged_attention_launcher(  # noqa: PLR0913
     # Perform unchecked size accesses, assume has already been checked
    batch_size, num_query_heads, head_size = out.shape
    num_cache_blocks, cache_block_size, num_kv_heads, _ = key_cache.shape
-    _, max_num_blocks_per_sequence = block_tables.shape
+    _, max_num_blocks_per_sequence = block_table.shape
+    _, max_num_splits, _, _ = output_scratchpad.shape
 
     assert cache_block_size == triton.next_power_of_2(cache_block_size), "Cache block size must be a power of two!"  # noqa: S101
 
@@ -497,10 +493,10 @@ def paged_attention_launcher(  # noqa: PLR0913
 
     # What is the maximum number of stage 1 kernels to launch per batch/head?
     # Each kernel processes up to {cache_block_size} tokens at a time (in many cases cache_block_size=32 for vLLM), so we can process
-    # a sequence up to {MAX_NUM_SPLITS * cache_block_size} == 64 * 32 == 2048 tokens before a stage 1 kernel will process multiple cache
+    # a sequence up to {max_num_splits * cache_block_size} == 64 * 32 == 2048 tokens before a stage 1 kernel will process multiple cache
     # blocks. This helps to reduce the overhead of kernel launches / split reduction for long sequences.
     # Note: we may need to tune this value for a given HW platform.
-    num_splits = min(max_num_blocks_per_sequence, MAX_NUM_SPLITS)
+    num_splits = min(max_num_blocks_per_sequence, max_num_splits)
 
     num_cache_blocks_per_split = triton.cdiv(max_num_blocks_per_sequence, num_splits)
 
@@ -527,7 +523,7 @@ def paged_attention_launcher(  # noqa: PLR0913
         query,
         key_cache,
         value_cache,
-        block_tables,
+        block_table,
         seq_lens,
         k_scale,
         v_scale,
@@ -550,7 +546,7 @@ def paged_attention_launcher(  # noqa: PLR0913
         key_cache.stride(1),
         key_cache.stride(2),
         key_cache.stride(3),
-        block_tables.stride(0),
+        block_table.stride(0),
         # Constexpr sizes
         cxpr_cache_block_size,
         cxpr_head_size_padded,
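
With the module-level MAX_NUM_SPLITS constant removed, the launcher now reads its split budget from the second dimension of output_scratchpad, so the caller fixes that budget when allocating the scratchpads. A hypothetical caller-side sketch (not from this commit; the tensor names, sizes, and float32 dtype are assumptions) that keeps the previous budget of 64:

import torch
import triton

batch_size, num_query_heads, head_size = 8, 32, 128
max_num_blocks_per_sequence = 128
max_num_splits = 64  # what MAX_NUM_SPLITS used to hard-code

# Scratchpad shapes mirror the launcher docstrings, with max_num_splits as dim 1.
output_scratchpad = torch.empty(batch_size, max_num_splits, num_query_heads, head_size, dtype=torch.float32)
lse_scratchpad = torch.empty(batch_size, max_num_splits, num_query_heads, dtype=torch.float32)

# Same arithmetic as the launcher: 128 cache blocks capped at 64 splits means
# 64 stage-1 kernels per batch/head, each covering 2 cache blocks.
num_splits = min(max_num_blocks_per_sequence, max_num_splits)                       # 64
num_cache_blocks_per_split = triton.cdiv(max_num_blocks_per_sequence, num_splits)   # 2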
