Skip to content

Commit 4a828d4

Browse files
committed
Add GEMM-based standard SDPA benchmark
Add bench_sdpa.cpp with a standalone GEMM-based SDPA implementation (run_standard_sdpa) alongside ExecuTorch's tiled flash attention (custom_sdpa_out) for comparative benchmarking. The standalone SDPA uses full GEMM per head with 3-pass softmax and supports both [B,S,H,D] and [B,H,S,D] layouts via BLAS leading dimension parameters, allowing isolation of algorithm vs layout effects. Includes validation tests that verify the GEMM-based implementation matches custom_sdpa_out within tolerance. Differential Revision: [D96044313](https://our.internmc.facebook.com/intern/diff/D96044313/) ghstack-source-id: 361224784 Pull Request resolved: #18646
1 parent 44e344c commit 4a828d4

File tree

2 files changed

+503
-0
lines changed

2 files changed

+503
-0
lines changed

extension/llm/custom_ops/BUCK

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
2+
load("@fbcode_macros//build_defs:cpp_benchmark.bzl", "cpp_benchmark")
23
oncall("executorch")
34
# Any targets that should be shared between fbcode and xplat must be defined in
45
# targets.bzl. This file can contain xplat-only targets.
@@ -89,3 +90,16 @@ fbcode_target(_kind = runtime.python_test,
8990
"//caffe2:torch",
9091
],
9192
)
93+
94+
fbcode_target(_kind = cpp_benchmark,
95+
name = "bench_sdpa",
96+
srcs = ["bench_sdpa.cpp"],
97+
deps = [
98+
"fbsource//third-party/benchmark:benchmark",
99+
"//executorch/extension/llm/custom_ops:custom_ops_mkl_noomp",
100+
"//executorch/extension/threadpool:threadpool",
101+
"//executorch/kernels/optimized:libblas",
102+
"//executorch/runtime/core/exec_aten:lib",
103+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
104+
],
105+
)

0 commit comments

Comments
 (0)