Commit bed5f68

Merge pull request #26 from stackav-oss/feature/jmanning/update-benchmarks

Speedup varlen (again)

2 parents: f16b84f + 33bf58a

5 files changed: 30 additions & 98 deletions

File tree:

benchmarks/paged_attention_benchmark.py
benchmarks/paged_attention_vs_flash_benchmark.py
benchmarks/varlen_attention_benchmark.py
conch/kernels/attention/varlen_attention.py
tools/create_benchmark_results_table.py

benchmarks/paged_attention_benchmark.py

Lines changed: 3 additions & 3 deletions
@@ -26,7 +26,7 @@
     "--head-dim",
     required=True,
     type=int,
-    default=256,
+    default=128,
     help="Head dimension",
 )
 @click.option(
@@ -47,14 +47,14 @@
     "--batch-size",
     required=False,
     type=int,
-    default=4,
+    default=128,
     help="Batch size",
 )
 @click.option(
     "--num-query-heads",
     required=False,
     type=int,
-    default=8,
+    default=32,
     help="Number of query heads",
 )
 @click.option(

benchmarks/paged_attention_vs_flash_benchmark.py

Lines changed: 3 additions & 3 deletions
@@ -26,7 +26,7 @@
     "--head-dim",
     required=True,
     type=int,
-    default=256,
+    default=128,
     help="Head dimension",
 )
 @click.option(
@@ -47,14 +47,14 @@
     "--batch-size",
     required=False,
     type=int,
-    default=4,
+    default=128,
     help="Batch size",
 )
 @click.option(
     "--num-query-heads",
     required=False,
     type=int,
-    default=8,
+    default=32,
     help="Number of query heads",
 )
 @click.option(
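
Both paged-attention benchmarks above move to the same defaults: head dim 128, batch size 128, 32 query heads. For orientation, a minimal sketch of the kind of paged KV-cache layout such benchmarks exercise — every shape besides those three defaults is hypothetical (num_blocks, block_size, num_kv_heads, and max_blocks_per_seq are illustrative only, and the benchmark's actual allocation may differ):

import torch

# New benchmark defaults from this commit
batch_size, num_query_heads, head_dim = 128, 32, 128

# Hypothetical paged-cache shapes, assuming a vLLM-style layout
num_blocks, block_size, num_kv_heads = 4096, 16, 8
key_cache = torch.empty(num_blocks, block_size, num_kv_heads, head_dim, dtype=torch.float16)
value_cache = torch.empty_like(key_cache)

# Each sequence maps its logical KV blocks to physical cache blocks via a block
# table; this is the structure the kernel diff further down walks when it loads
# physical_cache_block_number from the block_table row for the current batch.
max_blocks_per_seq = 32
block_table = torch.randint(0, num_blocks, (batch_size, max_blocks_per_seq), dtype=torch.int32)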

benchmarks/varlen_attention_benchmark.py

Lines changed: 5 additions & 5 deletions
@@ -27,14 +27,14 @@
     "--head-dim",
     required=True,
     type=int,
-    default=256,
+    default=128,
     help="Head dimension",
 )
 @click.option(
     "--seq-len",
     required=True,
     type=int,
-    default=1024,
+    default=512,
     help="Sequence length (for k/v)",
 )
 @click.option(
@@ -48,21 +48,21 @@
     "--batch-size",
     required=False,
     type=int,
-    default=10,
+    default=64,
     help="Batch size",
 )
 @click.option(
     "--num-query-heads",
     required=False,
     type=int,
-    default=8,
+    default=32,
     help="Number of query heads",
 )
 @click.option(
     "--num-kv-heads",
     required=False,
     type=int,
-    default=4,
+    default=8,
     help="Number of kv heads",
 )
 @click.option(
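
The varlen benchmark's new defaults describe a 4:1 grouped-query configuration (32 query heads sharing 8 KV heads). A sketch of the problem size they imply, assuming the packed variable-length layout (total_tokens, num_heads, head_dim) with a cu_seqlens offset vector that varlen attention kernels conventionally take; the benchmark's real tensor construction may differ:

import torch

# New defaults after this commit
batch_size, seq_len, head_dim = 64, 512, 128
num_query_heads, num_kv_heads = 32, 8

# Packed varlen layout: all sequences concatenated along the token dimension
total_tokens = batch_size * seq_len  # uniform lengths for simplicity
query = torch.randn(total_tokens, num_query_heads, head_dim, dtype=torch.float16)
key = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=torch.float16)
value = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=torch.float16)

# cu_seqlens[i] is where sequence i starts in the packed token dimension
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32)
assert cu_seqlens[-1] == total_tokens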

conch/kernels/attention/varlen_attention.py

Lines changed: 9 additions & 86 deletions
@@ -13,70 +13,10 @@
 import triton.language as tl
 from triton.language.extra import libdevice  # type: ignore[attr-defined]
 
-# Note: adding `bool` or `str` type annotations to these load/store helper functions doesn't work
-
 # FP8 representation on CUDA and ROCm
 _FP8_DTYPES: Final = [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
 
 
-@triton.jit  # type: ignore[misc]
-def _load_2d_block_ptr(  # type: ignore[no-untyped-def]
-    data_ptr: tl.tensor,
-    mask_first_dim,
-    mask_second_dim,
-    padding_option,
-) -> tl.tensor:
-    """Load a 2D tensor with custom strides and offsets."""
-    if mask_first_dim and mask_second_dim:
-        # Load with boundary check on both dimensions
-        data = tl.load(data_ptr, boundary_check=(0, 1), padding_option=padding_option)
-    elif mask_first_dim:
-        # Load with boundary check on first dimension only
-        data = tl.load(data_ptr, boundary_check=(0,), padding_option=padding_option)
-    elif mask_second_dim:
-        # Load with boundary check on second dimension only
-        data = tl.load(data_ptr, boundary_check=(1,), padding_option=padding_option)
-    else:
-        # Load without boundary check
-        data = tl.load(data_ptr)
-
-    return data
-
-
-@triton.jit  # type: ignore[misc]
-def _load(  # type: ignore[no-untyped-def]
-    data_ptr: tl.tensor,
-    use_mask,
-    mask: tl.tensor,
-    other,
-) -> tl.tensor:
-    """Load a 1D tensor with custom strides and offsets."""
-    if use_mask:
-        # Load with mask
-        data = tl.load(data_ptr, mask=mask, other=other)
-    else:
-        # Load without mask
-        data = tl.load(data_ptr)
-
-    return data
-
-
-@triton.jit  # type: ignore[misc]
-def _store(  # type: ignore[no-untyped-def]
-    data_ptr: tl.tensor,
-    value: tl.tensor,
-    use_mask,
-    mask: tl.tensor,
-) -> None:
-    """Store a 1D tensor with custom strides and offsets."""
-    if use_mask:
-        # Store with mask
-        tl.store(data_ptr, value, mask=mask)
-    else:
-        # Store without mask
-        tl.store(data_ptr, value)
-
-
 @triton.jit  # type: ignore[misc]
 def _varlen_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
     # Pointers to tensors
@@ -261,15 +201,11 @@ def _varlen_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
     # Mask out query elements that are just for padding
     query_mask = query_split_group_seq_mask[:, None] & query_split_group_head_mask[:, None] & head_mask[None, :]
 
-    # Determine whether or not we need masking for different dimensions
-    needs_query_split_mask = end_seqlen_q > this_query_length
-    needs_query_group_mask = query_group_size != cxpr_query_group_size_padded
-    needs_head_mask = head_size != cxpr_head_size_padded
-    needs_query_mask = (needs_query_split_mask or needs_query_group_mask) or needs_head_mask
+    # Only need causal masking if enabled and this program is not processing a decode
     needs_causal_mask = cxpr_is_causal and not is_pure_decode
 
     # Load queries
-    query = _load(query_ptr + query_offsets, use_mask=needs_query_mask, mask=query_mask, other=0.0)
+    query = tl.load(query_ptr + query_offsets, mask=query_mask, other=0.0)
 
     if cxpr_apply_fp8_scaling:
         q_scale = tl.load(q_scale_ptr)
@@ -301,9 +237,6 @@ def _varlen_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
         cache_block_offsets = tl.arange(0, cxpr_cache_block_size)
         cache_block_mask = cache_block_offsets < num_entries_in_cache_block
 
-        needs_cache_block_mask = num_entries_in_cache_block != cxpr_cache_block_size
-        needs_qk_mask = (needs_query_split_mask or needs_query_group_mask) or needs_cache_block_mask
-
         # Offset from the block_table row for the current batch by the number of cache blocks
         current_cache_block_number_ptr = current_block_table_ptr + cache_block_index
         physical_cache_block_number = tl.load(current_cache_block_number_ptr)
@@ -319,9 +252,8 @@ def _varlen_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
 
         key_block_mask = head_mask[:, None] & cache_block_mask[None, :]
 
-        key_block = _load(
+        key_block = tl.load(
             key_cache_ptr + kv_cache_block_index_offset + key_block_offsets,
-            use_mask=(needs_cache_block_mask or needs_head_mask),
             mask=key_block_mask,
             other=0.0,
         )
@@ -345,9 +277,7 @@ def _varlen_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
             causal_mask = query_split_group_seq_offsets[:, None] >= effective_seqlen_k_offsets[None, :]
             qk_mask = qk_mask & causal_mask
 
-        if needs_qk_mask or needs_causal_mask:
-            # Set masked out elements to -inf
-            qk = tl.where(qk_mask, qk, -float("inf")).to(dtype)
+        qk = tl.where(qk_mask, qk, -float("inf")).to(dtype)
 
         # Handle softcapping
         if cxpr_is_softcap:
@@ -378,9 +308,8 @@ def _varlen_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
 
         value_block_mask = cache_block_mask[:, None] & head_mask[None, :]
 
-        value_block = _load(
+        value_block = tl.load(
             value_cache_ptr + kv_cache_block_index_offset + value_block_offsets,
-            use_mask=(needs_cache_block_mask or needs_head_mask),
             mask=value_block_mask,
             other=0.0,
         )
@@ -415,10 +344,9 @@ def _varlen_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
     )
 
     # Store output scratchpad results
-    _store(
+    tl.store(
         output_scratchpad_ptr + output_scratch_offsets,
         output,
-        use_mask=needs_query_mask,
         mask=query_mask,
     )
 
@@ -439,10 +367,9 @@ def _varlen_attention_compute_splits_kernel(  # noqa: PLR0913, PLR0915
     lse_mask = query_split_group_seq_mask & query_split_group_head_mask
 
     # Store lse scratchpad results
-    _store(
+    tl.store(
         lse_scratchpad_ptr + lse_scratch_offsets,
         lse,
-        use_mask=(needs_query_split_mask or needs_query_group_mask),
         mask=lse_mask,
     )
 
@@ -522,8 +449,6 @@ def _varlen_attention_reduce_splits_kernel(  # noqa: PLR0913
     head_offsets = tl.arange(0, cxpr_head_size_padded)
     # Mask to only read valid indices of the actual head size
    head_mask = head_offsets < head_size
-    # Only enable masking if its necessary
-    needs_head_mask = head_size != cxpr_head_size_padded
 
     # Iterate through every cache block for the current sequence
     for kv_split_index in range(num_kv_splits_this_seq):
@@ -537,9 +462,8 @@ def _varlen_attention_reduce_splits_kernel(  # noqa: PLR0913
         )
 
         # Load output for this cache block, shape -> (cxpr_head_size_padded,)
-        block_output = _load(
+        block_output = tl.load(
            output_scratchpad_ptr + output_scratchpad_offsets,
-            use_mask=needs_head_mask,
            mask=head_mask,
            other=0.0,
        )
@@ -583,10 +507,9 @@ def _varlen_attention_reduce_splits_kernel(  # noqa: PLR0913
     output_offsets = batch_index * output_batch_stride + query_head_index * output_head_stride + head_offsets
 
     # Store final result
-    _store(
+    tl.store(
         output_ptr + output_offsets,
         output,
-        use_mask=needs_head_mask,
         mask=head_mask,
     )
 
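
The kernel change is uniform: the branchy _load/_store helpers, and the needs_* flags that fed them, are gone, and every access is an unconditionally masked tl.load/tl.store. A minimal standalone sketch of that style (not code from this repo; running it requires a Triton-capable GPU):

import torch
import triton
import triton.language as tl


@triton.jit
def _masked_copy_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # One program per block; the mask covers the ragged final block.
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Unconditionally masked load/store, as the varlen kernel now does everywhere:
    # out-of-bounds lanes read `other` on load and are skipped on store.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr + offsets, x, mask=mask)


x = torch.randn(1000, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
_masked_copy_kernel[grid](x, out, x.numel(), BLOCK_SIZE=256)
assert torch.equal(x, out)

Written this way, correctness no longer depends on per-program bookkeeping about which blocks happen to be full.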

tools/create_benchmark_results_table.py

Lines changed: 10 additions & 1 deletion
@@ -24,6 +24,7 @@
     "GeLU, Tanh, and Mul": "gelu_tanh_and_mul_benchmark",
     "SiLU and Mul": "silu_and_mul_benchmark",
     "Paged Attention": "paged_attention_vs_flash_benchmark",
+    "Varlen Attention": "varlen_attention_benchmark",
     "Rotary Embedding": "rotary_embedding_benchmark",
     "RMS Norm (Gemma-style)": "gemma_rms_norm_benchmark",
     "RMS Norm (Llama-style)": "rms_norm_benchmark",
@@ -45,6 +46,11 @@
     "unknown": [],
 }
 
+# Add any extra flags for each benchmark here
+_EXTRA_BENCHMARK_FLAGS: Final = {
+    "varlen_attention_benchmark": ["--causal"],
+}
+
 
 @click.command()
 @click.option(
@@ -90,9 +96,12 @@ def main(results_directory: Path, use_cached_results: bool) -> None:
         # Run benchmark and redirect output
         print(f"Running benchmark for {op_name}...")
 
+        # Some benchmark args are flags to enable things that default false, so we add any per-benchmark here
+        extra_flags = _EXTRA_BENCHMARK_FLAGS[benchmark_name] if benchmark_name in _EXTRA_BENCHMARK_FLAGS else []
+
         with results_csv.open("w") as results_file:
             run(
-                ["python", f"benchmarks/{benchmark_name}.py", "--csv"],
+                ["python", f"benchmarks/{benchmark_name}.py", "--csv"] + extra_flags,
                 check=True,
                 stdout=results_file,
                 env=os.environ,
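
Given the lookup above, the command the table tool now issues for the varlen row is equivalent to the following (assuming `run` here is `subprocess.run`, which matches the check/stdout/env keywords in the diff):

from subprocess import run

# extra_flags == ["--causal"] for varlen_attention_benchmark; every other
# benchmark falls back to [] and keeps its previous command line
# (the stdout/env arguments are omitted in this sketch)
run(["python", "benchmarks/varlen_attention_benchmark.py", "--csv", "--causal"], check=True)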
