Skip to content

Commit af95d2e

Browse files
Update benchmarks
1 parent f16b84f commit af95d2e

4 files changed

Lines changed: 21 additions & 12 deletions

File tree

benchmarks/paged_attention_benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
"--head-dim",
2727
required=True,
2828
type=int,
29-
default=256,
29+
default=128,
3030
help="Head dimension",
3131
)
3232
@click.option(
@@ -47,14 +47,14 @@
4747
"--batch-size",
4848
required=False,
4949
type=int,
50-
default=4,
50+
default=128,
5151
help="Batch size",
5252
)
5353
@click.option(
5454
"--num-query-heads",
5555
required=False,
5656
type=int,
57-
default=8,
57+
default=32,
5858
help="Number of query heads",
5959
)
6060
@click.option(

benchmarks/paged_attention_vs_flash_benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"--head-dim",
2727
required=True,
2828
type=int,
29-
default=256,
29+
default=128,
3030
help="Head dimension",
3131
)
3232
@click.option(
@@ -47,14 +47,14 @@
4747
"--batch-size",
4848
required=False,
4949
type=int,
50-
default=4,
50+
default=128,
5151
help="Batch size",
5252
)
5353
@click.option(
5454
"--num-query-heads",
5555
required=False,
5656
type=int,
57-
default=8,
57+
default=32,
5858
help="Number of query heads",
5959
)
6060
@click.option(

benchmarks/varlen_attention_benchmark.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,14 @@
2727
"--head-dim",
2828
required=True,
2929
type=int,
30-
default=256,
30+
default=128,
3131
help="Head dimension",
3232
)
3333
@click.option(
3434
"--seq-len",
3535
required=True,
3636
type=int,
37-
default=1024,
37+
default=512,
3838
help="Sequence length (for k/v)",
3939
)
4040
@click.option(
@@ -48,21 +48,21 @@
4848
"--batch-size",
4949
required=False,
5050
type=int,
51-
default=10,
51+
default=64,
5252
help="Batch size",
5353
)
5454
@click.option(
5555
"--num-query-heads",
5656
required=False,
5757
type=int,
58-
default=8,
58+
default=32,
5959
help="Number of query heads",
6060
)
6161
@click.option(
6262
"--num-kv-heads",
6363
required=False,
6464
type=int,
65-
default=4,
65+
default=8,
6666
help="Number of kv heads",
6767
)
6868
@click.option(

tools/create_benchmark_results_table.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"GeLU, Tanh, and Mul": "gelu_tanh_and_mul_benchmark",
2525
"SiLU and Mul": "silu_and_mul_benchmark",
2626
"Paged Attention": "paged_attention_vs_flash_benchmark",
27+
"Varlen Attention": "varlen_attention_benchmark",
2728
"Rotary Embedding": "rotary_embedding_benchmark",
2829
"RMS Norm (Gemma-style)": "gemma_rms_norm_benchmark",
2930
"RMS Norm (Llama-style)": "rms_norm_benchmark",
@@ -45,6 +46,11 @@
4546
"unknown": [],
4647
}
4748

49+
# Add any extra flags for each benchmark here
50+
_EXTRA_BENCHMARK_FLAGS: Final = {
51+
"varlen_attention_benchmark": ["--causal"],
52+
}
53+
4854

4955
@click.command()
5056
@click.option(
@@ -90,9 +96,12 @@ def main(results_directory: Path, use_cached_results: bool) -> None:
9096
# Run benchmark and redirect output
9197
print(f"Running benchmark for {op_name}...")
9298

99+
# Some benchmark options are boolean flags that default to false, so we append any per-benchmark extra flags here
100+
extra_flags = _EXTRA_BENCHMARK_FLAGS[benchmark_name] if benchmark_name in _EXTRA_BENCHMARK_FLAGS else []
101+
93102
with results_csv.open("w") as results_file:
94103
run(
95-
["python", f"benchmarks/{benchmark_name}.py", "--csv"],
104+
["python", f"benchmarks/{benchmark_name}.py", "--csv"] + extra_flags,
96105
check=True,
97106
stdout=results_file,
98107
env=os.environ,

0 commit comments

Comments
 (0)