
Commit e968c4a

Add GEMM-based standard SDPA benchmark and fix custom_sdpa_out signature
Adds a standalone GEMM-based (non-tiled) attention benchmark alongside the existing ET flash attention benchmark, allowing direct comparison of the two algorithms.

The standard SDPA uses cblas_sgemm for Q@K^T and scores@V with a 3-pass softmax, matching the approach used by ONNX Runtime's GQA operator. Both transposed [B,H,S,D] and standard [B,S,H,D] cache layouts are supported via the BLAS leading-dimension parameters. Validation tests run before the benchmarks to ensure correctness against ET's custom_sdpa_out.

Also fixes broken custom_sdpa_out calls to match the new 3-bool signature (is_seq_dim_2, is_k_seq_dim_2, is_v_seq_dim_2).

Authored with Claude.

Differential Revision: [D99677686](https://our.internmc.facebook.com/intern/diff/D99677686/)

[ghstack-poisoned]
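The GEMM-based approach in the message above can be illustrated with a minimal single-head sketch. This is not the code from this commit: `naive_gemm` is a hypothetical stand-in for `cblas_sgemm` so the example builds without a BLAS library, `softmax_3pass` and `sdpa_head` are illustrative names, and the sketch assumes non-causal attention with no mask. The point it shows is that one routine can serve both cache layouts because only the row stride (leading dimension) and base pointer change.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical stand-in for cblas_sgemm (row-major): C[M,N] = A[M,K] * op(B).
// With trans_b, B is stored as [N,K]; otherwise as [K,N]. lda/ldb/ldc are
// row strides (leading dimensions) in elements.
void naive_gemm(bool trans_b, int M, int N, int K,
                const float* A, int lda, const float* B, int ldb,
                float* C, int ldc) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        const float b = trans_b ? B[j * ldb + k] : B[k * ldb + j];
        acc += A[i * lda + k] * b;
      }
      C[i * ldc + j] = acc;
    }
  }
}

// Three passes over one row: find the max, exponentiate and sum, normalize.
void softmax_3pass(float* row, int n) {
  float m = row[0];
  for (int i = 1; i < n; ++i) m = std::max(m, row[i]);
  float sum = 0.0f;
  for (int i = 0; i < n; ++i) {
    row[i] = std::exp(row[i] - m);
    sum += row[i];
  }
  for (int i = 0; i < n; ++i) row[i] /= sum;
}

// Non-tiled attention for one head (assumption: no causal mask).
// q: [Sq, D], k and v: [Skv, D], out: [Sq, D], each with its own row stride.
void sdpa_head(const float* q, int ldq, const float* k, int ldk,
               const float* v, int ldv, float* out, int ldo,
               int Sq, int Skv, int D) {
  assert(ldq >= D && ldk >= D && ldv >= D && ldo >= D);
  std::vector<float> scores(static_cast<size_t>(Sq) * Skv);
  // scores = Q @ K^T
  naive_gemm(true, Sq, Skv, D, q, ldq, k, ldk, scores.data(), Skv);
  const float scale = 1.0f / std::sqrt(static_cast<float>(D));
  for (int i = 0; i < Sq; ++i) {
    float* row = scores.data() + static_cast<size_t>(i) * Skv;
    for (int j = 0; j < Skv; ++j) row[j] *= scale;
    softmax_3pass(row, Skv);
  }
  // out = scores @ V
  naive_gemm(false, Sq, D, Skv, scores.data(), Skv, v, ldv, out, ldo);
}
```

Under these assumptions, head `h` of a [B,H,S,D] cache starts at offset `h * S * D` with row stride `D`, while the same head in a [B,S,H,D] cache starts at offset `h * D` with row stride `H * D`; the call to `sdpa_head` is otherwise identical.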
1 parent 45468e1 commit e968c4a

2 files changed

Lines changed: 381 additions & 0 deletions


extension/llm/custom_ops/BUCK

Lines changed: 2 additions & 0 deletions
@@ -139,6 +139,8 @@ cpp_benchmark(
     deps = [
         "fbsource//third-party/benchmark:benchmark",
         "//executorch/extension/llm/custom_ops:custom_ops_mkl_noomp",
+        "//executorch/extension/threadpool:threadpool",
+        "//executorch/kernels/optimized:libblas",
         "//executorch/runtime/core/exec_aten:lib",
         "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
     ],
