From 2bd6e067b4c4874b8e7426b3f56fc11d75e7b428 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 6 Apr 2026 08:47:40 -0700 Subject: [PATCH] Add benchmarks for transposed vs standard KV cache layout MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Benchmarks comparing transposed [B, H, S, D] vs standard [B, S, H, D] KV cache layouts in custom_sdpa and update_cache ops using Llama 3 8B config (32 Q heads, 8 KV heads, D=128). Both C++ (Google Benchmark) and Python benchmarks are included, covering decode (seq_len=1) at various cache fill levels and prefill scenarios. Results on Apple M-series show transposed cache significantly improves SDPA performance at longer cache fills (1.64x at start_pos=1024, 1.13x for prefill seq_len=512) due to better memory locality in the attn_score @ V GEMM — V stride along S_kv changes from H*D to D. Authored with Claude. Differential Revision: [D99677680](https://our.internmc.facebook.com/intern/diff/D99677680/) [ghstack-poisoned] --- extension/llm/custom_ops/BUCK | 26 ++ .../llm/custom_ops/bench_transposed_cache.cpp | 372 ++++++++++++++++++ .../llm/custom_ops/bench_transposed_cache.py | 343 ++++++++++++++++ 3 files changed, 741 insertions(+) create mode 100644 extension/llm/custom_ops/bench_transposed_cache.cpp create mode 100644 extension/llm/custom_ops/bench_transposed_cache.py diff --git a/extension/llm/custom_ops/BUCK b/extension/llm/custom_ops/BUCK index 3b20c25bd2d..3ff3769c0f7 100644 --- a/extension/llm/custom_ops/BUCK +++ b/extension/llm/custom_ops/BUCK @@ -105,6 +105,21 @@ fbcode_target(_kind = runtime.python_test, ], ) +fbcode_target(_kind = runtime.python_binary, + name = "bench_transposed_cache_py", + srcs = [ + "bench_transposed_cache.py", + ], + main_function = "executorch.extension.llm.custom_ops.bench_transposed_cache.main", + preload_deps = [ + ":custom_ops_aot_lib_mkl_noomp", + ":custom_ops_aot_py", + ], + deps = [ + "//caffe2:torch", + ], +) + fbcode_target(_kind = 
cpp_benchmark,
     name = "bench_sdpa",
     srcs = ["bench_sdpa.cpp"],
@@ -117,3 +132,14 @@ fbcode_target(_kind = cpp_benchmark,
         "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
     ],
 )
+
+cpp_benchmark(
+    name = "bench_transposed_cache",
+    srcs = ["bench_transposed_cache.cpp"],
+    deps = [
+        "fbsource//third-party/benchmark:benchmark",
+        "//executorch/extension/llm/custom_ops:custom_ops_mkl_noomp",
+        "//executorch/runtime/core/exec_aten:lib",
+        "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+    ],
+)
diff --git a/extension/llm/custom_ops/bench_transposed_cache.cpp b/extension/llm/custom_ops/bench_transposed_cache.cpp
new file mode 100644
index 00000000000..1c669592d96
--- /dev/null
+++ b/extension/llm/custom_ops/bench_transposed_cache.cpp
@@ -0,0 +1,372 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+/*
+ * Benchmark to compare performance of transposed vs standard KV cache layout
+ * for custom_sdpa and update_cache ops.
+ *
+ * Standard layout: [Batch, Seq, Heads, HeadDim] (is_seq_dim_2=false)
+ * Transposed layout: [Batch, Heads, Seq, HeadDim] (is_seq_dim_2=true)
+ *
+ * The hypothesis is that transposed cache improves GEMM performance because:
+ * - In attn_score @ V: V stride along S_kv changes from H*D to D
+ * - In Q @ K^T: K stride similarly improves from H*D to D
+ */
+
+#include <benchmark/benchmark.h>
+
+#include <cstdint>
+#include <optional>
+#include <random>
+
+#include <executorch/extension/llm/custom_ops/op_sdpa.h>
+#include <executorch/extension/llm/custom_ops/op_update_cache.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using executorch::runtime::KernelRuntimeContext;
+using executorch::runtime::testing::TensorFactory;
+
+namespace {
+
+// Fill a float tensor with random data in [0, 1)
+void fill_random(Tensor& t, std::mt19937& gen) {
+  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+  float* data = t.mutable_data_ptr<float>();
+  for (int64_t i = 0; i < t.numel(); ++i) {
+    data[i] = dist(gen);
+  }
+}
+
+} // namespace
+
+// Benchmark fixture that sets up tensors for SDPA benchmarking.
+// Uses std::optional because ExecuTorch Tensor has a deleted default ctor.
+class SDPABenchFixture : public benchmark::Fixture {
+ public:
+  // Args: {batch, num_heads_q, num_heads_kv, head_dim, max_seq_len, start_pos,
+  // query_seq_len, is_transposed}
+  void SetUp(benchmark::State& state) override {
+    int64_t batch = state.range(0);
+    int64_t num_heads_q = state.range(1);
+    int64_t num_heads_kv = state.range(2);
+    int64_t head_dim = state.range(3);
+    int64_t max_seq_len = state.range(4);
+    int64_t start_pos = state.range(5);
+    int64_t q_seq_len = state.range(6);
+    bool is_transposed = state.range(7) != 0;
+
+    std::mt19937 gen(42);
+
+    if (is_transposed) {
+      // [B, H, S, D]
+      q_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_q,
+                            (int32_t)q_seq_len, (int32_t)head_dim}));
+      k_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_kv,
+                                  (int32_t)max_seq_len, (int32_t)head_dim}));
+      v_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_kv,
+                                  (int32_t)max_seq_len, (int32_t)head_dim}));
+      output_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_q,
+                                 (int32_t)q_seq_len, (int32_t)head_dim}));
+    } else {
+      // [B, S, H, D]
+      q_.emplace(tf_.zeros({(int32_t)batch, (int32_t)q_seq_len,
+                            (int32_t)num_heads_q, (int32_t)head_dim}));
+      k_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
+                                  (int32_t)num_heads_kv, (int32_t)head_dim}));
+      v_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
+                                  (int32_t)num_heads_kv, (int32_t)head_dim}));
+      output_.emplace(tf_.zeros({(int32_t)batch, (int32_t)q_seq_len,
+                                 (int32_t)num_heads_q, (int32_t)head_dim}));
+    }
+
+    fill_random(*q_, gen);
+    fill_random(*k_cache_, gen);
+    fill_random(*v_cache_, gen);
+
+    start_pos_ = start_pos;
+    is_transposed_ = is_transposed;
+  }
+
+  void TearDown(benchmark::State&) override {
+    q_.reset();
+    k_cache_.reset();
+    v_cache_.reset();
+    output_.reset();
+  }
+
+  TensorFactory<ScalarType::Float> tf_;
+  std::optional<Tensor> q_;
+  std::optional<Tensor> k_cache_;
+  std::optional<Tensor> v_cache_;
+
  std::optional<Tensor> output_;
+  int64_t start_pos_ = 0;
+  bool is_transposed_ = false;
+};
+
+// Benchmark custom_sdpa with causal masking
+BENCHMARK_DEFINE_F(SDPABenchFixture, CustomSDPA)
+(benchmark::State& state) {
+  for (auto _ : state) {
+    KernelRuntimeContext ctx{};
+    torch::executor::native::custom_sdpa_out(
+        ctx,
+        *q_,
+        *k_cache_,
+        *v_cache_,
+        start_pos_,
+        std::nullopt, // attn_mask
+        0.0, // dropout_p
+        true, // is_causal
+        std::nullopt, // scale
+        is_transposed_,
+        *output_);
+  }
+}
+
+// Benchmark fixture for update_cache
+class UpdateCacheBenchFixture : public benchmark::Fixture {
+ public:
+  // Args: {batch, num_heads, head_dim, max_seq_len, start_pos,
+  // update_seq_len, is_transposed}
+  void SetUp(benchmark::State& state) override {
+    int64_t batch = state.range(0);
+    int64_t num_heads = state.range(1);
+    int64_t head_dim = state.range(2);
+    int64_t max_seq_len = state.range(3);
+    int64_t start_pos = state.range(4);
+    int64_t update_seq_len = state.range(5);
+    bool is_transposed = state.range(6) != 0;
+
+    std::mt19937 gen(42);
+
+    if (is_transposed) {
+      // [B, H, S, D]
+      value_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads,
+                                (int32_t)update_seq_len, (int32_t)head_dim}));
+      cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads,
+                                (int32_t)max_seq_len, (int32_t)head_dim}));
+    } else {
+      // [B, S, H, D]
+      value_.emplace(tf_.zeros({(int32_t)batch, (int32_t)update_seq_len,
+                                (int32_t)num_heads, (int32_t)head_dim}));
+      cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
+                                (int32_t)num_heads, (int32_t)head_dim}));
+    }
+
+    fill_random(*value_, gen);
+    fill_random(*cache_, gen);
+    // Output is a dummy placeholder (unused by update_cache_out)
+    update_output_.emplace(tf_.zeros({1}));
+
+    start_pos_ = start_pos;
+    is_transposed_ = is_transposed;
+  }
+
+  void TearDown(benchmark::State&) override {
+    value_.reset();
+    cache_.reset();
+    update_output_.reset();
+  }
+
+  TensorFactory<ScalarType::Float> tf_;
+  std::optional<Tensor> value_;
+  std::optional<Tensor> cache_;
+
  std::optional<Tensor> update_output_;
+  int64_t start_pos_ = 0;
+  bool is_transposed_ = false;
+};
+
+// Benchmark update_cache
+BENCHMARK_DEFINE_F(UpdateCacheBenchFixture, UpdateCache)
+(benchmark::State& state) {
+  for (auto _ : state) {
+    KernelRuntimeContext ctx{};
+    torch::executor::native::update_cache_out(
+        ctx,
+        *value_,
+        *cache_,
+        start_pos_,
+        is_transposed_,
+        *update_output_);
+  }
+}
+
+// Combined update_cache + custom_sdpa fixture
+class CombinedBenchFixture : public benchmark::Fixture {
+ public:
+  // Args: {batch, num_heads_q, num_heads_kv, head_dim, max_seq_len, start_pos,
+  // seq_len, is_transposed}
+  void SetUp(benchmark::State& state) override {
+    int64_t batch = state.range(0);
+    int64_t num_heads_q = state.range(1);
+    int64_t num_heads_kv = state.range(2);
+    int64_t head_dim = state.range(3);
+    int64_t max_seq_len = state.range(4);
+    int64_t start_pos = state.range(5);
+    int64_t seq_len = state.range(6);
+    bool is_transposed = state.range(7) != 0;
+
+    std::mt19937 gen(42);
+
+    if (is_transposed) {
+      // [B, H, S, D]
+      q_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_q,
+                            (int32_t)seq_len, (int32_t)head_dim}));
+      k_proj_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_kv,
+                                 (int32_t)seq_len, (int32_t)head_dim}));
+      v_proj_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_kv,
+                                 (int32_t)seq_len, (int32_t)head_dim}));
+      k_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_kv,
+                                  (int32_t)max_seq_len, (int32_t)head_dim}));
+      v_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_kv,
+                                  (int32_t)max_seq_len, (int32_t)head_dim}));
+      output_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_q,
+                                 (int32_t)seq_len, (int32_t)head_dim}));
+    } else {
+      // [B, S, H, D]
+      q_.emplace(tf_.zeros({(int32_t)batch, (int32_t)seq_len,
+                            (int32_t)num_heads_q, (int32_t)head_dim}));
+      k_proj_.emplace(tf_.zeros({(int32_t)batch, (int32_t)seq_len,
+                                 (int32_t)num_heads_kv, (int32_t)head_dim}));
+
      v_proj_.emplace(tf_.zeros({(int32_t)batch, (int32_t)seq_len,
+                                 (int32_t)num_heads_kv, (int32_t)head_dim}));
+      k_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
+                                  (int32_t)num_heads_kv, (int32_t)head_dim}));
+      v_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
+                                  (int32_t)num_heads_kv, (int32_t)head_dim}));
+      output_.emplace(tf_.zeros({(int32_t)batch, (int32_t)seq_len,
+                                 (int32_t)num_heads_q, (int32_t)head_dim}));
+    }
+
+    fill_random(*q_, gen);
+    fill_random(*k_proj_, gen);
+    fill_random(*v_proj_, gen);
+    fill_random(*k_cache_, gen);
+    fill_random(*v_cache_, gen);
+
+    update_output_.emplace(tf_.zeros({1}));
+    start_pos_ = start_pos;
+    is_transposed_ = is_transposed;
+  }
+
+  void TearDown(benchmark::State&) override {
+    q_.reset();
+    k_proj_.reset();
+    v_proj_.reset();
+    k_cache_.reset();
+    v_cache_.reset();
+    output_.reset();
+    update_output_.reset();
+  }
+
+  TensorFactory<ScalarType::Float> tf_;
+  std::optional<Tensor> q_;
+  std::optional<Tensor> k_proj_;
+  std::optional<Tensor> v_proj_;
+  std::optional<Tensor> k_cache_;
+  std::optional<Tensor> v_cache_;
+  std::optional<Tensor> output_;
+  std::optional<Tensor> update_output_;
+  int64_t start_pos_ = 0;
+  bool is_transposed_ = false;
+};
+
+// Benchmark combined update_cache + custom_sdpa
+BENCHMARK_DEFINE_F(CombinedBenchFixture, CombinedUpdateSDPA)
+(benchmark::State& state) {
+  for (auto _ : state) {
+    KernelRuntimeContext ctx{};
+    torch::executor::native::update_cache_out(
+        ctx, *k_proj_, *k_cache_, start_pos_, is_transposed_, *update_output_);
+    torch::executor::native::update_cache_out(
+        ctx, *v_proj_, *v_cache_, start_pos_, is_transposed_, *update_output_);
+    torch::executor::native::custom_sdpa_out(
+        ctx,
+        *q_,
+        *k_cache_,
+        *v_cache_,
+        start_pos_,
+        std::nullopt, // attn_mask
+        0.0, // dropout_p
+        true, // is_causal
+        std::nullopt, // scale
+        is_transposed_,
+        *output_);
+  }
+}
+
+/*
+ * Benchmark configurations modeled after Llama 3 8B (GQA: 32 q heads, 8 kv
+ * heads, head_dim=128).
We test decode (seq_len=1) and prefill scenarios at + * various cache fill levels, comparing standard vs transposed layout. + */ + +// --- custom_sdpa: Decode (seq_len=1) --- +// Args: {batch, Hq, Hkv, D, MaxS, StartPos, SeqLen, Transposed} +BENCHMARK_REGISTER_F(SDPABenchFixture, CustomSDPA) + // Standard layout decode at various cache positions + ->Args({1, 32, 8, 128, 2048, 0, 1, 0}) + ->Args({1, 32, 8, 128, 2048, 64, 1, 0}) + ->Args({1, 32, 8, 128, 2048, 256, 1, 0}) + ->Args({1, 32, 8, 128, 2048, 512, 1, 0}) + ->Args({1, 32, 8, 128, 2048, 1024, 1, 0}) + // Transposed layout decode at same positions + ->Args({1, 32, 8, 128, 2048, 0, 1, 1}) + ->Args({1, 32, 8, 128, 2048, 64, 1, 1}) + ->Args({1, 32, 8, 128, 2048, 256, 1, 1}) + ->Args({1, 32, 8, 128, 2048, 512, 1, 1}) + ->Args({1, 32, 8, 128, 2048, 1024, 1, 1}) + // Standard layout prefill + ->Args({1, 32, 8, 128, 2048, 0, 128, 0}) + ->Args({1, 32, 8, 128, 2048, 0, 512, 0}) + // Transposed layout prefill + ->Args({1, 32, 8, 128, 2048, 0, 128, 1}) + ->Args({1, 32, 8, 128, 2048, 0, 512, 1}) + // Llama 2 style (32 heads, no GQA) + ->Args({1, 32, 32, 128, 2048, 256, 1, 0}) + ->Args({1, 32, 32, 128, 2048, 256, 1, 1}) + ->ArgNames( + {"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"}); + +// --- update_cache --- +// Args: {batch, H, D, MaxS, StartPos, SeqLen, Transposed} +BENCHMARK_REGISTER_F(UpdateCacheBenchFixture, UpdateCache) + // Decode (seq_len=1) + ->Args({1, 8, 128, 2048, 0, 1, 0}) + ->Args({1, 8, 128, 2048, 256, 1, 0}) + ->Args({1, 8, 128, 2048, 1024, 1, 0}) + ->Args({1, 8, 128, 2048, 0, 1, 1}) + ->Args({1, 8, 128, 2048, 256, 1, 1}) + ->Args({1, 8, 128, 2048, 1024, 1, 1}) + // Prefill + ->Args({1, 8, 128, 2048, 0, 128, 0}) + ->Args({1, 8, 128, 2048, 0, 512, 0}) + ->Args({1, 8, 128, 2048, 0, 128, 1}) + ->Args({1, 8, 128, 2048, 0, 512, 1}) + ->ArgNames({"B", "H", "D", "MaxS", "StartPos", "SeqLen", "Trans"}); + +// --- Combined: update_cache + custom_sdpa --- +// Args: {batch, Hq, Hkv, D, MaxS, StartPos, 
SeqLen, Transposed} +BENCHMARK_REGISTER_F(CombinedBenchFixture, CombinedUpdateSDPA) + // Decode at various positions + ->Args({1, 32, 8, 128, 2048, 0, 1, 0}) + ->Args({1, 32, 8, 128, 2048, 256, 1, 0}) + ->Args({1, 32, 8, 128, 2048, 512, 1, 0}) + ->Args({1, 32, 8, 128, 2048, 1024, 1, 0}) + ->Args({1, 32, 8, 128, 2048, 0, 1, 1}) + ->Args({1, 32, 8, 128, 2048, 256, 1, 1}) + ->Args({1, 32, 8, 128, 2048, 512, 1, 1}) + ->Args({1, 32, 8, 128, 2048, 1024, 1, 1}) + // Prefill + ->Args({1, 32, 8, 128, 2048, 0, 128, 0}) + ->Args({1, 32, 8, 128, 2048, 0, 128, 1}) + ->ArgNames( + {"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"}); + +int main(int argc, char** argv) { + benchmark::Initialize(&argc, argv); + benchmark::RunSpecifiedBenchmarks(); + benchmark::Shutdown(); + return 0; +} diff --git a/extension/llm/custom_ops/bench_transposed_cache.py b/extension/llm/custom_ops/bench_transposed_cache.py new file mode 100644 index 00000000000..a8cd8d051d5 --- /dev/null +++ b/extension/llm/custom_ops/bench_transposed_cache.py @@ -0,0 +1,343 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +""" +Benchmark to compare performance of transposed vs standard KV cache layout +for custom_sdpa and update_cache ops. + +Standard layout: [Batch, Seq, Heads, HeadDim] (is_seq_dim_2=False) +Transposed layout: [Batch, Heads, Seq, HeadDim] (is_seq_dim_2=True) + +The hypothesis is that transposed cache may improve GEMM performance in +custom_sdpa because: + - In attn_score @ V: V's stride along the S_kv dimension changes from + H*D (strided) to D (contiguous), improving memory access patterns. + - In Q @ K^T: K's stride similarly improves from H*D to D. 
+""" + +import argparse +import time +from typing import Dict, List, Tuple + +import torch + +from executorch.extension.llm.custom_ops import custom_ops # noqa + + +def benchmark_fn(fn, warmup: int = 10, iterations: int = 100) -> float: + """Run fn for warmup iterations, then measure average time over iterations.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() if torch.cuda.is_available() else None + + start = time.perf_counter() + for _ in range(iterations): + fn() + torch.cuda.synchronize() if torch.cuda.is_available() else None + end = time.perf_counter() + + return (end - start) / iterations + + +# Model configurations: (num_heads_kv, num_heads_q, head_dim, max_seq_len) +MODEL_CONFIGS = { + "llama3_8b": (8, 32, 128, 2048), + "llama3_70b": (8, 64, 128, 2048), + "llama2_7b": (32, 32, 128, 2048), + "small": (4, 8, 64, 512), +} + + +def bench_custom_sdpa( + batch_size: int, + num_heads_kv: int, + num_heads_q: int, + head_dim: int, + max_seq_len: int, + start_pos: int, + seq_len: int, + warmup: int = 10, + iterations: int = 100, +) -> Dict[str, float]: + """ + Benchmark custom_sdpa with both cache layouts. + + Returns dict with times for standard and transposed layouts. 
+ """ + results = {} + + # Standard layout: [B, S, H, D] + q_std = torch.rand(batch_size, seq_len, num_heads_q, head_dim) + k_cache_std = torch.rand(batch_size, max_seq_len, num_heads_kv, head_dim) + v_cache_std = torch.rand(batch_size, max_seq_len, num_heads_kv, head_dim) + + def run_std(): + torch.ops.llama.custom_sdpa( + q_std, k_cache_std, v_cache_std, + start_pos, None, 0.0, True, None, False, + ) + + results["standard"] = benchmark_fn(run_std, warmup, iterations) + + # Transposed layout: [B, H, S, D] + q_trans = q_std.transpose(1, 2).contiguous() + k_cache_trans = k_cache_std.transpose(1, 2).contiguous() + v_cache_trans = v_cache_std.transpose(1, 2).contiguous() + + def run_trans(): + torch.ops.llama.custom_sdpa( + q_trans, k_cache_trans, v_cache_trans, + start_pos, None, 0.0, True, None, True, + ) + + results["transposed"] = benchmark_fn(run_trans, warmup, iterations) + + return results + + +def bench_update_cache( + batch_size: int, + num_heads_kv: int, + head_dim: int, + max_seq_len: int, + start_pos: int, + seq_len: int, + warmup: int = 10, + iterations: int = 100, +) -> Dict[str, float]: + """ + Benchmark update_cache with both cache layouts. + + Returns dict with times for standard and transposed layouts. 
+ """ + results = {} + + # Standard layout: [B, S, H, D] + value_std = torch.rand(batch_size, seq_len, num_heads_kv, head_dim) + cache_std = torch.zeros(batch_size, max_seq_len, num_heads_kv, head_dim) + + def run_std(): + torch.ops.llama.update_cache(value_std, cache_std, start_pos, False) + + results["standard"] = benchmark_fn(run_std, warmup, iterations) + + # Transposed layout: [B, H, S, D] + value_trans = value_std.transpose(1, 2).contiguous() + cache_trans = cache_std.transpose(1, 2).contiguous() + + def run_trans(): + torch.ops.llama.update_cache(value_trans, cache_trans, start_pos, True) + + results["transposed"] = benchmark_fn(run_trans, warmup, iterations) + + return results + + +def bench_combined_update_and_sdpa( + batch_size: int, + num_heads_kv: int, + num_heads_q: int, + head_dim: int, + max_seq_len: int, + start_pos: int, + seq_len: int, + warmup: int = 10, + iterations: int = 100, +) -> Dict[str, float]: + """ + Benchmark combined update_cache + custom_sdpa to simulate a full attention + step, which is the real-world usage pattern. 
+ """ + results = {} + + # Standard layout + q_std = torch.rand(batch_size, seq_len, num_heads_q, head_dim) + k_proj_std = torch.rand(batch_size, seq_len, num_heads_kv, head_dim) + v_proj_std = torch.rand(batch_size, seq_len, num_heads_kv, head_dim) + k_cache_std = torch.rand(batch_size, max_seq_len, num_heads_kv, head_dim) + v_cache_std = torch.rand(batch_size, max_seq_len, num_heads_kv, head_dim) + + def run_std(): + torch.ops.llama.update_cache(k_proj_std, k_cache_std, start_pos, False) + torch.ops.llama.update_cache(v_proj_std, v_cache_std, start_pos, False) + torch.ops.llama.custom_sdpa( + q_std, k_cache_std, v_cache_std, + start_pos, None, 0.0, True, None, False, + ) + + results["standard"] = benchmark_fn(run_std, warmup, iterations) + + # Transposed layout + q_trans = q_std.transpose(1, 2).contiguous() + k_proj_trans = k_proj_std.transpose(1, 2).contiguous() + v_proj_trans = v_proj_std.transpose(1, 2).contiguous() + k_cache_trans = k_cache_std.transpose(1, 2).contiguous() + v_cache_trans = v_cache_std.transpose(1, 2).contiguous() + + def run_trans(): + torch.ops.llama.update_cache(k_proj_trans, k_cache_trans, start_pos, True) + torch.ops.llama.update_cache(v_proj_trans, v_cache_trans, start_pos, True) + torch.ops.llama.custom_sdpa( + q_trans, k_cache_trans, v_cache_trans, + start_pos, None, 0.0, True, None, True, + ) + + results["transposed"] = benchmark_fn(run_trans, warmup, iterations) + + return results + + +def format_results( + label: str, + results: Dict[str, float], +) -> str: + """Format benchmark results into a readable string.""" + std_us = results["standard"] * 1e6 + trans_us = results["transposed"] * 1e6 + speedup = results["standard"] / results["transposed"] + return ( + f" {label:40s} std={std_us:10.1f} us trans={trans_us:10.1f} us " + f"speedup={speedup:.3f}x" + ) + + +def run_benchmarks( + config_name: str = "llama3_8b", + batch_size: int = 1, + warmup: int = 20, + iterations: int = 200, + num_threads: int = 1, +): + """Run all benchmarks for 
a given model configuration.""" + if config_name not in MODEL_CONFIGS: + raise ValueError( + f"Unknown config: {config_name}. " + f"Available: {list(MODEL_CONFIGS.keys())}" + ) + + num_heads_kv, num_heads_q, head_dim, max_seq_len = MODEL_CONFIGS[config_name] + + # Set thread count to isolate from OMP variability + torch.set_num_threads(num_threads) + + print(f"\n{'=' * 90}") + print(f"Config: {config_name} B={batch_size} H_kv={num_heads_kv} " + f"H_q={num_heads_q} D={head_dim} max_S={max_seq_len} " + f"threads={num_threads}") + print(f"Warmup={warmup} Iterations={iterations}") + print(f"{'=' * 90}") + + # Decode phase (seq_len=1) at various cache positions + print("\n--- custom_sdpa: Decode (seq_len=1) ---") + for start_pos in [0, 64, 256, 512, 1024]: + if start_pos >= max_seq_len: + continue + results = bench_custom_sdpa( + batch_size, num_heads_kv, num_heads_q, head_dim, + max_seq_len, start_pos, seq_len=1, + warmup=warmup, iterations=iterations, + ) + print(format_results(f"start_pos={start_pos}", results)) + + # Prefill phase (various seq_len) at start_pos=0 + print("\n--- custom_sdpa: Prefill (start_pos=0) ---") + for seq_len in [32, 64, 128, 256, 512]: + if seq_len >= max_seq_len: + continue + results = bench_custom_sdpa( + batch_size, num_heads_kv, num_heads_q, head_dim, + max_seq_len, start_pos=0, seq_len=seq_len, + warmup=warmup, iterations=iterations, + ) + print(format_results(f"seq_len={seq_len}", results)) + + # update_cache: Decode (seq_len=1) + print("\n--- update_cache: Decode (seq_len=1) ---") + for start_pos in [0, 64, 256, 512, 1024]: + if start_pos >= max_seq_len: + continue + results = bench_update_cache( + batch_size, num_heads_kv, head_dim, + max_seq_len, start_pos, seq_len=1, + warmup=warmup, iterations=iterations, + ) + print(format_results(f"start_pos={start_pos}", results)) + + # update_cache: Prefill + print("\n--- update_cache: Prefill (start_pos=0) ---") + for seq_len in [32, 64, 128, 256, 512]: + if seq_len >= max_seq_len: + continue 
+ results = bench_update_cache( + batch_size, num_heads_kv, head_dim, + max_seq_len, start_pos=0, seq_len=seq_len, + warmup=warmup, iterations=iterations, + ) + print(format_results(f"seq_len={seq_len}", results)) + + # Combined: update_cache + custom_sdpa (realistic attention step) + print("\n--- Combined (update_cache + custom_sdpa): Decode ---") + for start_pos in [0, 64, 256, 512, 1024]: + if start_pos >= max_seq_len: + continue + results = bench_combined_update_and_sdpa( + batch_size, num_heads_kv, num_heads_q, head_dim, + max_seq_len, start_pos, seq_len=1, + warmup=warmup, iterations=iterations, + ) + print(format_results(f"start_pos={start_pos}", results)) + + print("\n--- Combined (update_cache + custom_sdpa): Prefill ---") + for seq_len in [32, 128]: + if seq_len >= max_seq_len: + continue + results = bench_combined_update_and_sdpa( + batch_size, num_heads_kv, num_heads_q, head_dim, + max_seq_len, start_pos=0, seq_len=seq_len, + warmup=warmup, iterations=iterations, + ) + print(format_results(f"seq_len={seq_len}", results)) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark transposed vs standard KV cache layout" + ) + parser.add_argument( + "--config", type=str, default="llama3_8b", + choices=list(MODEL_CONFIGS.keys()), + help="Model configuration to benchmark", + ) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--iterations", type=int, default=200) + parser.add_argument( + "--num-threads", type=int, default=1, + help="Number of threads for torch operations", + ) + parser.add_argument( + "--all-configs", action="store_true", + help="Run benchmarks for all model configurations", + ) + args = parser.parse_args() + + if args.all_configs: + for config_name in MODEL_CONFIGS: + run_benchmarks( + config_name, args.batch_size, + args.warmup, args.iterations, args.num_threads, + ) + else: + run_benchmarks( + args.config, args.batch_size, + 
args.warmup, args.iterations, args.num_threads, + ) + + +if __name__ == "__main__": + main()