Commit 91f9213

Update on "Add GEMM-based standard SDPA benchmark"
Add bench_sdpa.cpp with a standalone GEMM-based SDPA implementation (run_standard_sdpa) alongside ExecuTorch's tiled flash attention (custom_sdpa_out) for comparative benchmarking. The standalone SDPA uses a full GEMM per head with a 3-pass softmax and supports both [B,S,H,D] and [B,H,S,D] layouts via BLAS leading-dimension parameters, allowing algorithm effects to be isolated from layout effects. Includes validation tests that verify the GEMM-based implementation matches custom_sdpa_out within tolerance.

Differential Revision: [D96044313](https://our.internmc.facebook.com/intern/diff/D96044313/)

[ghstack-poisoned]
2 parents: 5cbe9f5 + dc0d80e

1 file changed: extension/llm/custom_ops/bench_sdpa.cpp
Lines changed: 142 additions & 55 deletions
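The commit message's note that both layouts are supported "via BLAS leading dimension parameters" comes down to one observation: for a fixed (batch, head) pair, a GEMM only needs a base pointer plus the stride between consecutive sequence positions. A minimal sketch of that indexing, with hypothetical names not taken from this file:

    #include <cstdint>

    // Locate the Q rows of one attention head and report the stride (leading
    // dimension) between consecutive sequence positions, for either layout.
    struct HeadView {
      const float* base; // pointer to element (b, h, s = 0, d = 0)
      int64_t ld;        // stride in floats between sequence positions
    };

    HeadView q_head_view(
        const float* q,
        int64_t b,
        int64_t h,
        int64_t B,
        int64_t H,
        int64_t S,
        int64_t D,
        bool transposed) {
      if (transposed) {
        // [B, H, S, D]: one head's rows are contiguous, so ld = D.
        return {q + (b * H + h) * S * D, D};
      }
      // [B, S, H, D]: consecutive positions are H * D floats apart.
      return {q + (b * S * H + h) * D, H * D};
    }

The same pointer arithmetic, applied to Q, K, V, and the output, is what lets a single GEMM-based kernel serve both layouts without moving any data.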
@@ -1,4 +1,10 @@
-// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
 
 /*
  * Benchmark for SDPA (scaled dot-product attention) implementations.
@@ -49,7 +55,8 @@ void fill_random(Tensor& t, std::mt19937& gen) {
 // NOTE: executorch::cpublas::gemm uses COLUMN-MAJOR (Fortran) convention
 // internally (CblasColMajor). Our data is row-major. The conversion is:
 // row-major C[M,N] = A[M,K] * trans(B)[K,N] (B stored as [N,K])
-// becomes col-major: gemm(Trans, NoTrans, N, M, K, a, B, ldb, A, lda, b, C, ldc)
+// becomes col-major: gemm(Trans, NoTrans, N, M, K, a, B, ldb, A, lda, b, C,
+// ldc)
 // where lda/ldb/ldc are the row-major strides (row strides = col-major ld).
 void run_standard_sdpa(
     const float* q_data,
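The NOTE above compresses the whole trick into one line, so a self-contained demonstration may help. A row-major matrix reinterpreted as column-major storage is its transpose; asking a column-major GEMM for C^T = op(B) * op(A), with the operands swapped, therefore yields row-major C. The sketch below substitutes a naive column-major gemm for executorch::cpublas::gemm (same argument order as the calls in this file); the helper name and the toy values are illustrative only:

    #include <cstdio>
    #include <vector>

    // Column-major C[m,n] = alpha * op(A)[m,k] * op(B)[k,n] + beta * C[m,n].
    void naive_colmajor_gemm(
        bool trans_a, bool trans_b, int m, int n, int k,
        float alpha, const float* a, int lda, const float* b, int ldb,
        float beta, float* c, int ldc) {
      for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
          float acc = 0.0f;
          for (int p = 0; p < k; ++p) {
            float av = trans_a ? a[p + i * lda] : a[i + p * lda];
            float bv = trans_b ? b[j + p * ldb] : b[p + j * ldb];
            acc += av * bv;
          }
          c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
        }
      }
    }

    int main() {
      // Row-major A[M,K] and B stored as [N,K]; we want row-major C = A * B^T.
      const int M = 2, K = 3, N = 2;
      std::vector<float> A = {1, 2, 3, 4, 5, 6}; // rows [1,2,3], [4,5,6]
      std::vector<float> B = {1, 0, 1, 0, 1, 0}; // rows [1,0,1], [0,1,0]
      std::vector<float> C(M * N, 0.0f);
      // Same swap as in the file: gemm(Trans, NoTrans, N, M, K, ...) with
      // row-major row strides passed as column-major leading dimensions.
      naive_colmajor_gemm(
          true, false, N, M, K, 1.0f, B.data(), K, A.data(), K, 0.0f,
          C.data(), N);
      for (int i = 0; i < M; ++i) {
        printf("%g %g\n", C[i * N], C[i * N + 1]); // prints: 4 2 / 10 5
      }
      return 0;
    }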
@@ -104,10 +111,19 @@ void run_standard_sdpa(
       // Row-major: scores[qSeqLen,kvSize] = Q[qSeqLen,D] @ K^T[D,kvSize]
       // Col-major: gemm(Trans, NoTrans, kvSize, qSeqLen, D, ...)
       executorch::cpublas::gemm(
-          TransposeType::Transpose, TransposeType::NoTranspose,
-          kvSize, q_seq_len, D,
-          1.0f, k_ptr, ldk, q_ptr, ldq,
-          0.0f, scores, kvSize);
+          TransposeType::Transpose,
+          TransposeType::NoTranspose,
+          kvSize,
+          q_seq_len,
+          D,
+          1.0f,
+          k_ptr,
+          ldk,
+          q_ptr,
+          ldq,
+          0.0f,
+          scores,
+          kvSize);
 
       // Scale, causal mask, and softmax per query row
       for (int64_t qi = 0; qi < q_seq_len; ++qi) {
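The "scale, causal mask, and softmax per query row" step named in the context line above is untouched by this diff, but it is the "3-pass softmax" from the commit message: one pass to scale, mask, and find the row max, one to exponentiate and sum, one to normalize. A hedged sketch of that pattern for a single query row (illustrative names, not the file's exact loop body):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    void scale_mask_softmax_row(
        float* row,        // scores for one query position, length kv_size
        int64_t kv_size,   // = start_pos + q_seq_len
        int64_t start_pos, // absolute position of the first query in the chunk
        int64_t qi,        // query index within the chunk
        float scale) {     // typically 1 / sqrt(head_dim)
      const int64_t last_allowed = start_pos + qi; // causal horizon
      // Pass 1: scale, mask out future positions, and track the row max.
      float max_val = -INFINITY;
      for (int64_t ki = 0; ki < kv_size; ++ki) {
        row[ki] = (ki <= last_allowed) ? row[ki] * scale : -INFINITY;
        max_val = std::max(max_val, row[ki]);
      }
      // Pass 2: exponentiate shifted by the max (numerical stability) and sum.
      float sum = 0.0f;
      for (int64_t ki = 0; ki < kv_size; ++ki) {
        row[ki] = std::exp(row[ki] - max_val);
        sum += row[ki];
      }
      // Pass 3: normalize so the row sums to 1.
      const float inv = 1.0f / sum;
      for (int64_t ki = 0; ki < kv_size; ++ki) {
        row[ki] *= inv;
      }
    }

Masked positions become exp(-inf - max) = 0, so they contribute nothing to the weighted sum over V in the second GEMM below.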
@@ -139,10 +155,19 @@ void run_standard_sdpa(
       // Row-major: output[qSeqLen,D] = scores[qSeqLen,kvSize] @ V[kvSize,D]
       // Col-major: gemm(NoTrans, NoTrans, D, qSeqLen, kvSize, ...)
       executorch::cpublas::gemm(
-          TransposeType::NoTranspose, TransposeType::NoTranspose,
-          D, q_seq_len, kvSize,
-          1.0f, v_ptr, ldv, scores, kvSize,
-          0.0f, out_ptr, ldo);
+          TransposeType::NoTranspose,
+          TransposeType::NoTranspose,
+          D,
+          q_seq_len,
+          kvSize,
+          1.0f,
+          v_ptr,
+          ldv,
+          scores,
+          kvSize,
+          0.0f,
+          out_ptr,
+          ldo);
     }
   });
 }
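Both GEMM mappings in run_standard_sdpa follow from the same identity: a row-major matrix read as column-major storage is its transpose, so the column-major routine is asked for the transposed product:

    scores^T = (Q K^T)^T = K Q^T           ->  gemm(Transpose, NoTranspose, kvSize, qSeqLen, D, ...)
    out^T    = (scores V)^T = V^T scores^T ->  gemm(NoTranspose, NoTranspose, D, qSeqLen, kvSize, ...)

K's column-major view is already K^T, so the Transpose flag recovers the K that K Q^T needs; V's column-major view V^T is exactly what the second product wants, hence NoTranspose. Writing the transposed result with the row-major row stride as ldc lands it in memory as the row-major scores and output.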
@@ -174,8 +199,8 @@ bool validate_config(
   std::mt19937 gen(42);
 
   // Standard [B, S, H, D] layout
-  Tensor q = tf.zeros(
-      {(int32_t)batch, (int32_t)q_seq_len, (int32_t)Hq, (int32_t)D});
+  Tensor q =
+      tf.zeros({(int32_t)batch, (int32_t)q_seq_len, (int32_t)Hq, (int32_t)D});
   Tensor k = tf.zeros(
       {(int32_t)batch, (int32_t)max_seq_len, (int32_t)Hkv, (int32_t)D});
   Tensor v = tf.zeros(
@@ -186,17 +211,15 @@ bool validate_config(
   fill_random(v, gen);
 
   // Reference: ET custom_sdpa_out (10-param signature, standard layout)
-  Tensor out_ref = tf.zeros(
-      {(int32_t)batch, (int32_t)q_seq_len, (int32_t)Hq, (int32_t)D});
+  Tensor out_ref =
+      tf.zeros({(int32_t)batch, (int32_t)q_seq_len, (int32_t)Hq, (int32_t)D});
   KernelRuntimeContext ctx{};
   torch::executor::native::custom_sdpa_out(
-      ctx, q, k, v, start_pos,
-      std::nullopt, 0.0, true, std::nullopt,
-      out_ref);
+      ctx, q, k, v, start_pos, std::nullopt, 0.0, true, std::nullopt, out_ref);
 
   // Test: GEMM-based standard SDPA
-  Tensor out_test = tf.zeros(
-      {(int32_t)batch, (int32_t)q_seq_len, (int32_t)Hq, (int32_t)D});
+  Tensor out_test =
+      tf.zeros({(int32_t)batch, (int32_t)q_seq_len, (int32_t)Hq, (int32_t)D});
   int64_t kvSize = start_pos + q_seq_len;
   std::vector<float> scores_buf(batch * Hq * q_seq_len * kvSize);
   run_standard_sdpa(
@@ -205,7 +228,13 @@ bool validate_config(
       v.const_data_ptr<float>(),
       out_test.mutable_data_ptr<float>(),
       scores_buf.data(),
-      batch, Hq, Hkv, D, max_seq_len, start_pos, q_seq_len,
+      batch,
+      Hq,
+      Hkv,
+      D,
+      max_seq_len,
+      start_pos,
+      q_seq_len,
       false /* is_transposed */);
 
   float diff = max_abs_diff(out_ref, out_test);
@@ -215,16 +244,29 @@ bool validate_config(
         stderr,
         "FAIL: StandardSDPA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
         "max_abs_diff=%.6e > atol=%.6e\n",
-        mode, (long)batch, (long)Hq, (long)Hkv, (long)D,
-        (long)start_pos, (long)q_seq_len, diff, atol);
+        mode,
+        (long)batch,
+        (long)Hq,
+        (long)Hkv,
+        (long)D,
+        (long)start_pos,
+        (long)q_seq_len,
+        diff,
+        atol);
     return false;
   }
   fprintf(
       stderr,
       "PASS: StandardSDPA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
      "max_abs_diff=%.6e\n",
-      mode, (long)batch, (long)Hq, (long)Hkv, (long)D,
-      (long)start_pos, (long)q_seq_len, diff);
+      mode,
+      (long)batch,
+      (long)Hq,
+      (long)Hkv,
+      (long)D,
+      (long)start_pos,
+      (long)q_seq_len,
+      diff);
 
   return true;
 }
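For reference, max_abs_diff above is the pass/fail criterion against atol. Its definition is not part of this diff; the obvious element-wise form would be (hypothetical sketch over raw buffers, not the file's helper):

    #include <cmath>
    #include <cstdint>

    // Largest element-wise absolute difference between two equal-size buffers.
    float max_abs_diff_sketch(const float* a, const float* b, int64_t n) {
      float m = 0.0f;
      for (int64_t i = 0; i < n; ++i) {
        m = std::fmax(m, std::fabs(a[i] - b[i]));
      }
      return m;
    }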
@@ -277,14 +319,26 @@ class SDPABenchFixture : public benchmark::Fixture {
    std::mt19937 gen(42);
 
    // Standard [B, S, H, D] layout
-    q_.emplace(tf_.zeros({(int32_t)batch, (int32_t)q_seq_len,
-                          (int32_t)num_heads_q, (int32_t)head_dim}));
-    k_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
-                                (int32_t)num_heads_kv, (int32_t)head_dim}));
-    v_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
-                                (int32_t)num_heads_kv, (int32_t)head_dim}));
-    output_.emplace(tf_.zeros({(int32_t)batch, (int32_t)q_seq_len,
-                               (int32_t)num_heads_q, (int32_t)head_dim}));
+    q_.emplace(tf_.zeros(
+        {(int32_t)batch,
+         (int32_t)q_seq_len,
+         (int32_t)num_heads_q,
+         (int32_t)head_dim}));
+    k_cache_.emplace(tf_.zeros(
+        {(int32_t)batch,
+         (int32_t)max_seq_len,
+         (int32_t)num_heads_kv,
+         (int32_t)head_dim}));
+    v_cache_.emplace(tf_.zeros(
+        {(int32_t)batch,
+         (int32_t)max_seq_len,
+         (int32_t)num_heads_kv,
+         (int32_t)head_dim}));
+    output_.emplace(tf_.zeros(
+        {(int32_t)batch,
+         (int32_t)q_seq_len,
+         (int32_t)num_heads_q,
+         (int32_t)head_dim}));
 
     fill_random(*q_, gen);
     fill_random(*k_cache_, gen);
@@ -320,8 +374,8 @@ BENCHMARK_DEFINE_F(SDPABenchFixture, CustomSDPA)
       *v_cache_,
       start_pos_,
       std::nullopt, // attn_mask
-      0.0, // dropout_p
-      true, // is_causal
+      0.0,          // dropout_p
+      true,         // is_causal
       std::nullopt, // scale
       *output_);
 }
@@ -347,24 +401,48 @@ class StandardSDPABenchFixture : public benchmark::Fixture {
 
     if (is_transposed) {
       // [B, H, S, D]
-      q_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_q,
-                            (int32_t)q_seq_len, (int32_t)head_dim}));
-      k_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_kv,
-                                  (int32_t)max_seq_len, (int32_t)head_dim}));
-      v_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_kv,
-                                  (int32_t)max_seq_len, (int32_t)head_dim}));
-      output_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_q,
-                                 (int32_t)q_seq_len, (int32_t)head_dim}));
+      q_.emplace(tf_.zeros(
+          {(int32_t)batch,
+           (int32_t)num_heads_q,
+           (int32_t)q_seq_len,
+           (int32_t)head_dim}));
+      k_cache_.emplace(tf_.zeros(
+          {(int32_t)batch,
+           (int32_t)num_heads_kv,
+           (int32_t)max_seq_len,
+           (int32_t)head_dim}));
+      v_cache_.emplace(tf_.zeros(
+          {(int32_t)batch,
+           (int32_t)num_heads_kv,
+           (int32_t)max_seq_len,
+           (int32_t)head_dim}));
+      output_.emplace(tf_.zeros(
+          {(int32_t)batch,
+           (int32_t)num_heads_q,
+           (int32_t)q_seq_len,
+           (int32_t)head_dim}));
     } else {
       // [B, S, H, D]
-      q_.emplace(tf_.zeros({(int32_t)batch, (int32_t)q_seq_len,
-                            (int32_t)num_heads_q, (int32_t)head_dim}));
-      k_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
-                                  (int32_t)num_heads_kv, (int32_t)head_dim}));
-      v_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
-                                  (int32_t)num_heads_kv, (int32_t)head_dim}));
-      output_.emplace(tf_.zeros({(int32_t)batch, (int32_t)q_seq_len,
-                                 (int32_t)num_heads_q, (int32_t)head_dim}));
+      q_.emplace(tf_.zeros(
+          {(int32_t)batch,
+           (int32_t)q_seq_len,
+           (int32_t)num_heads_q,
+           (int32_t)head_dim}));
+      k_cache_.emplace(tf_.zeros(
+          {(int32_t)batch,
+           (int32_t)max_seq_len,
+           (int32_t)num_heads_kv,
+           (int32_t)head_dim}));
+      v_cache_.emplace(tf_.zeros(
+          {(int32_t)batch,
+           (int32_t)max_seq_len,
+           (int32_t)num_heads_kv,
+           (int32_t)head_dim}));
+      output_.emplace(tf_.zeros(
+          {(int32_t)batch,
+           (int32_t)q_seq_len,
+           (int32_t)num_heads_q,
+           (int32_t)head_dim}));
     }
 
     fill_random(*q_, gen);
@@ -423,9 +501,19 @@ BENCHMARK_DEFINE_F(StandardSDPABenchFixture, StandardSDPA)
 
   for (auto _ : state) {
     run_standard_sdpa(
-        q_data, k_data, v_data, out_data, scores_buf_.data(),
-        batch_, num_heads_q_, num_heads_kv_, head_dim_,
-        max_seq_len_, start_pos_, q_seq_len_, is_transposed_);
+        q_data,
+        k_data,
+        v_data,
+        out_data,
+        scores_buf_.data(),
+        batch_,
+        num_heads_q_,
+        num_heads_kv_,
+        head_dim_,
+        max_seq_len_,
+        start_pos_,
+        q_seq_len_,
+        is_transposed_);
   }
 }
 
@@ -475,8 +563,7 @@ BENCHMARK_REGISTER_F(StandardSDPABenchFixture, StandardSDPA)
     // Llama 2 style (32 heads, no GQA)
     ->Args({1, 32, 32, 128, 2048, 256, 1, 0})
     ->Args({1, 32, 32, 128, 2048, 256, 1, 1})
-    ->ArgNames(
-        {"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"});
+    ->ArgNames({"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"});
 
 } // namespace
 
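For readers unfamiliar with the registration pattern at the end of the diff: each ->Args({...}) call creates one parameterized benchmark instance, ->ArgNames(...) only labels those values in the reported benchmark name, and the fixture reads them back through state.range(i) in the same order. A self-contained toy example of the mechanism (hypothetical benchmark, not from this file):

    #include <benchmark/benchmark.h>
    #include <cstdint>

    static void BM_ArgsDemo(benchmark::State& state) {
      const int64_t B = state.range(0);  // labeled "B" below
      const int64_t Hq = state.range(1); // labeled "Hq" below
      for (auto _ : state) {
        benchmark::DoNotOptimize(B * Hq); // stand-in for real work
      }
    }
    // Reported as BM_ArgsDemo/B:1/Hq:32.
    BENCHMARK(BM_ArgsDemo)->Args({1, 32})->ArgNames({"B", "Hq"});
    BENCHMARK_MAIN();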