1- // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
1+ /*
2+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3+ * All rights reserved.
4+ *
5+ * This source code is licensed under the BSD-style license found in the
6+ * LICENSE file in the root directory of this source tree.
7+ */
28
39/*
410 * Benchmark for SDPA (scaled dot-product attention) implementations.
@@ -49,7 +55,8 @@ void fill_random(Tensor& t, std::mt19937& gen) {
4955// NOTE: executorch::cpublas::gemm uses COLUMN-MAJOR (Fortran) convention
5056// internally (CblasColMajor). Our data is row-major. The conversion is:
5157// row-major C[M,N] = A[M,K] * trans(B)[K,N] (B stored as [N,K])
52- // becomes col-major: gemm(Trans, NoTrans, N, M, K, a, B, ldb, A, lda, b, C, ldc)
58+ // becomes col-major: gemm(Trans, NoTrans, N, M, K, a, B, ldb, A, lda, b, C,
59+ // ldc)
5360// where lda/ldb/ldc are the row-major strides (row strides = col-major ld).
5461void run_standard_sdpa (
5562 const float * q_data,
@@ -104,10 +111,19 @@ void run_standard_sdpa(
104111 // Row-major: scores[qSeqLen,kvSize] = Q[qSeqLen,D] @ K^T[D,kvSize]
105112 // Col-major: gemm(Trans, NoTrans, kvSize, qSeqLen, D, ...)
106113 executorch::cpublas::gemm (
107- TransposeType::Transpose, TransposeType::NoTranspose,
108- kvSize, q_seq_len, D,
109- 1 .0f , k_ptr, ldk, q_ptr, ldq,
110- 0 .0f , scores, kvSize);
114+ TransposeType::Transpose,
115+ TransposeType::NoTranspose,
116+ kvSize,
117+ q_seq_len,
118+ D,
119+ 1 .0f ,
120+ k_ptr,
121+ ldk,
122+ q_ptr,
123+ ldq,
124+ 0 .0f ,
125+ scores,
126+ kvSize);
111127
112128 // Scale, causal mask, and softmax per query row
113129 for (int64_t qi = 0 ; qi < q_seq_len; ++qi) {
@@ -139,10 +155,19 @@ void run_standard_sdpa(
139155 // Row-major: output[qSeqLen,D] = scores[qSeqLen,kvSize] @ V[kvSize,D]
140156 // Col-major: gemm(NoTrans, NoTrans, D, qSeqLen, kvSize, ...)
141157 executorch::cpublas::gemm (
142- TransposeType::NoTranspose, TransposeType::NoTranspose,
143- D, q_seq_len, kvSize,
144- 1 .0f , v_ptr, ldv, scores, kvSize,
145- 0 .0f , out_ptr, ldo);
158+ TransposeType::NoTranspose,
159+ TransposeType::NoTranspose,
160+ D,
161+ q_seq_len,
162+ kvSize,
163+ 1 .0f ,
164+ v_ptr,
165+ ldv,
166+ scores,
167+ kvSize,
168+ 0 .0f ,
169+ out_ptr,
170+ ldo);
146171 }
147172 });
148173}
@@ -174,8 +199,8 @@ bool validate_config(
174199 std::mt19937 gen (42 );
175200
176201 // Standard [B, S, H, D] layout
177- Tensor q = tf. zeros (
178- {(int32_t )batch, (int32_t )q_seq_len, (int32_t )Hq, (int32_t )D});
202+ Tensor q =
203+ tf. zeros ( {(int32_t )batch, (int32_t )q_seq_len, (int32_t )Hq, (int32_t )D});
179204 Tensor k = tf.zeros (
180205 {(int32_t )batch, (int32_t )max_seq_len, (int32_t )Hkv, (int32_t )D});
181206 Tensor v = tf.zeros (
@@ -186,17 +211,15 @@ bool validate_config(
186211 fill_random (v, gen);
187212
188213 // Reference: ET custom_sdpa_out (10-param signature, standard layout)
189- Tensor out_ref = tf. zeros (
190- {(int32_t )batch, (int32_t )q_seq_len, (int32_t )Hq, (int32_t )D});
214+ Tensor out_ref =
215+ tf. zeros ( {(int32_t )batch, (int32_t )q_seq_len, (int32_t )Hq, (int32_t )D});
191216 KernelRuntimeContext ctx{};
192217 torch::executor::native::custom_sdpa_out (
193- ctx, q, k, v, start_pos,
194- std::nullopt , 0.0 , true , std::nullopt ,
195- out_ref);
218+ ctx, q, k, v, start_pos, std::nullopt , 0.0 , true , std::nullopt , out_ref);
196219
197220 // Test: GEMM-based standard SDPA
198- Tensor out_test = tf. zeros (
199- {(int32_t )batch, (int32_t )q_seq_len, (int32_t )Hq, (int32_t )D});
221+ Tensor out_test =
222+ tf. zeros ( {(int32_t )batch, (int32_t )q_seq_len, (int32_t )Hq, (int32_t )D});
200223 int64_t kvSize = start_pos + q_seq_len;
201224 std::vector<float > scores_buf (batch * Hq * q_seq_len * kvSize);
202225 run_standard_sdpa (
@@ -205,7 +228,13 @@ bool validate_config(
205228 v.const_data_ptr <float >(),
206229 out_test.mutable_data_ptr <float >(),
207230 scores_buf.data (),
208- batch, Hq, Hkv, D, max_seq_len, start_pos, q_seq_len,
231+ batch,
232+ Hq,
233+ Hkv,
234+ D,
235+ max_seq_len,
236+ start_pos,
237+ q_seq_len,
209238 false /* is_transposed */ );
210239
211240 float diff = max_abs_diff (out_ref, out_test);
@@ -215,16 +244,29 @@ bool validate_config(
215244 stderr,
216245 " FAIL: StandardSDPA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
217246 " max_abs_diff=%.6e > atol=%.6e\n " ,
218- mode, (long )batch, (long )Hq, (long )Hkv, (long )D,
219- (long )start_pos, (long )q_seq_len, diff, atol);
247+ mode,
248+ (long )batch,
249+ (long )Hq,
250+ (long )Hkv,
251+ (long )D,
252+ (long )start_pos,
253+ (long )q_seq_len,
254+ diff,
255+ atol);
220256 return false ;
221257 }
222258 fprintf (
223259 stderr,
224260 " PASS: StandardSDPA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
225261 " max_abs_diff=%.6e\n " ,
226- mode, (long )batch, (long )Hq, (long )Hkv, (long )D,
227- (long )start_pos, (long )q_seq_len, diff);
262+ mode,
263+ (long )batch,
264+ (long )Hq,
265+ (long )Hkv,
266+ (long )D,
267+ (long )start_pos,
268+ (long )q_seq_len,
269+ diff);
228270
229271 return true ;
230272}
@@ -277,14 +319,26 @@ class SDPABenchFixture : public benchmark::Fixture {
277319 std::mt19937 gen (42 );
278320
279321 // Standard [B, S, H, D] layout
280- q_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )q_seq_len,
281- (int32_t )num_heads_q, (int32_t )head_dim}));
282- k_cache_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )max_seq_len,
283- (int32_t )num_heads_kv, (int32_t )head_dim}));
284- v_cache_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )max_seq_len,
285- (int32_t )num_heads_kv, (int32_t )head_dim}));
286- output_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )q_seq_len,
287- (int32_t )num_heads_q, (int32_t )head_dim}));
322+ q_.emplace (tf_.zeros (
323+ {(int32_t )batch,
324+ (int32_t )q_seq_len,
325+ (int32_t )num_heads_q,
326+ (int32_t )head_dim}));
327+ k_cache_.emplace (tf_.zeros (
328+ {(int32_t )batch,
329+ (int32_t )max_seq_len,
330+ (int32_t )num_heads_kv,
331+ (int32_t )head_dim}));
332+ v_cache_.emplace (tf_.zeros (
333+ {(int32_t )batch,
334+ (int32_t )max_seq_len,
335+ (int32_t )num_heads_kv,
336+ (int32_t )head_dim}));
337+ output_.emplace (tf_.zeros (
338+ {(int32_t )batch,
339+ (int32_t )q_seq_len,
340+ (int32_t )num_heads_q,
341+ (int32_t )head_dim}));
288342
289343 fill_random (*q_, gen);
290344 fill_random (*k_cache_, gen);
@@ -320,8 +374,8 @@ BENCHMARK_DEFINE_F(SDPABenchFixture, CustomSDPA)
320374 *v_cache_,
321375 start_pos_,
322376 std::nullopt , // attn_mask
323- 0.0 , // dropout_p
324- true , // is_causal
377+ 0.0 , // dropout_p
378+ true , // is_causal
325379 std::nullopt , // scale
326380 *output_);
327381 }
@@ -347,24 +401,48 @@ class StandardSDPABenchFixture : public benchmark::Fixture {
347401
348402 if (is_transposed) {
349403 // [B, H, S, D]
350- q_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )num_heads_q,
351- (int32_t )q_seq_len, (int32_t )head_dim}));
352- k_cache_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )num_heads_kv,
353- (int32_t )max_seq_len, (int32_t )head_dim}));
354- v_cache_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )num_heads_kv,
355- (int32_t )max_seq_len, (int32_t )head_dim}));
356- output_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )num_heads_q,
357- (int32_t )q_seq_len, (int32_t )head_dim}));
404+ q_.emplace (tf_.zeros (
405+ {(int32_t )batch,
406+ (int32_t )num_heads_q,
407+ (int32_t )q_seq_len,
408+ (int32_t )head_dim}));
409+ k_cache_.emplace (tf_.zeros (
410+ {(int32_t )batch,
411+ (int32_t )num_heads_kv,
412+ (int32_t )max_seq_len,
413+ (int32_t )head_dim}));
414+ v_cache_.emplace (tf_.zeros (
415+ {(int32_t )batch,
416+ (int32_t )num_heads_kv,
417+ (int32_t )max_seq_len,
418+ (int32_t )head_dim}));
419+ output_.emplace (tf_.zeros (
420+ {(int32_t )batch,
421+ (int32_t )num_heads_q,
422+ (int32_t )q_seq_len,
423+ (int32_t )head_dim}));
358424 } else {
359425 // [B, S, H, D]
360- q_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )q_seq_len,
361- (int32_t )num_heads_q, (int32_t )head_dim}));
362- k_cache_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )max_seq_len,
363- (int32_t )num_heads_kv, (int32_t )head_dim}));
364- v_cache_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )max_seq_len,
365- (int32_t )num_heads_kv, (int32_t )head_dim}));
366- output_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )q_seq_len,
367- (int32_t )num_heads_q, (int32_t )head_dim}));
426+ q_.emplace (tf_.zeros (
427+ {(int32_t )batch,
428+ (int32_t )q_seq_len,
429+ (int32_t )num_heads_q,
430+ (int32_t )head_dim}));
431+ k_cache_.emplace (tf_.zeros (
432+ {(int32_t )batch,
433+ (int32_t )max_seq_len,
434+ (int32_t )num_heads_kv,
435+ (int32_t )head_dim}));
436+ v_cache_.emplace (tf_.zeros (
437+ {(int32_t )batch,
438+ (int32_t )max_seq_len,
439+ (int32_t )num_heads_kv,
440+ (int32_t )head_dim}));
441+ output_.emplace (tf_.zeros (
442+ {(int32_t )batch,
443+ (int32_t )q_seq_len,
444+ (int32_t )num_heads_q,
445+ (int32_t )head_dim}));
368446 }
369447
370448 fill_random (*q_, gen);
@@ -423,9 +501,19 @@ BENCHMARK_DEFINE_F(StandardSDPABenchFixture, StandardSDPA)
423501
424502 for (auto _ : state) {
425503 run_standard_sdpa (
426- q_data, k_data, v_data, out_data, scores_buf_.data (),
427- batch_, num_heads_q_, num_heads_kv_, head_dim_,
428- max_seq_len_, start_pos_, q_seq_len_, is_transposed_);
504+ q_data,
505+ k_data,
506+ v_data,
507+ out_data,
508+ scores_buf_.data (),
509+ batch_,
510+ num_heads_q_,
511+ num_heads_kv_,
512+ head_dim_,
513+ max_seq_len_,
514+ start_pos_,
515+ q_seq_len_,
516+ is_transposed_);
429517 }
430518}
431519
@@ -475,8 +563,7 @@ BENCHMARK_REGISTER_F(StandardSDPABenchFixture, StandardSDPA)
475563 // Llama 2 style (32 heads, no GQA)
476564 ->Args({1 , 32 , 32 , 128 , 2048 , 256 , 1 , 0 })
477565 ->Args({1 , 32 , 32 , 128 , 2048 , 256 , 1 , 1 })
478- ->ArgNames(
479- {" B" , " Hq" , " Hkv" , " D" , " MaxS" , " StartPos" , " SeqLen" , " Trans" });
566+ ->ArgNames({" B" , " Hq" , " Hkv" , " D" , " MaxS" , " StartPos" , " SeqLen" , " Trans" });
480567
481568} // namespace
482569
0 commit comments