@@ -172,17 +172,145 @@ void run_standard_sdpa(
172172 });
173173}
174174
175+ // ONNX Runtime GQA-style SDPA, faithfully ported from
176+ // onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h.
177+ // Differences from run_standard_sdpa:
178+ // 1. Scale in GEMM alpha (no separate scaling pass)
179+ // 2. Scores buffer padded to max_seq_len cols (ONNX's present_buffer_seq_len)
180+ // 3. Causal mask: zero out future positions, softmax on valid window only
181+ // 4. Output in [B, S, Hq, D] with stride Hq*D (ONNX's interleaved BNSH->BSNH)
182+ //
183+ // When is_transposed=true, inputs are [B,H,S,D]; output is [B,S,Hq,D].
184+ // When is_transposed=false, inputs are [B,S,H,D]; output is [B,S,Hq,D].
185+ // Output is always [B, S, Hq, D] to match ONNX's actual output format.
186+ void run_onnx_gqa_sdpa (
187+ const float * q_data,
188+ const float * k_data,
189+ const float * v_data,
190+ float * out_data, // always [B, q_seq_len, Hq, D]
191+ float * scores_buf, // must hold batch*Hq*q_seq_len*max_seq_len floats
192+ int64_t batch,
193+ int64_t Hq,
194+ int64_t Hkv,
195+ int64_t D,
196+ int64_t max_seq_len,
197+ int64_t start_pos,
198+ int64_t q_seq_len,
199+ bool is_transposed) {
200+ using executorch::cpublas::TransposeType;
201+
202+ const int64_t total_seqlen = start_pos + q_seq_len;
203+ const float alpha = 1 .0f / std::sqrt (static_cast <float >(D));
204+ const int64_t heads_per_group = Hq / Hkv;
205+ const int64_t hidden_size = Hq * D; // output row stride (ONNX convention)
206+
207+ // Input strides depend on layout
208+ const int64_t ldq = is_transposed ? D : Hq * D;
209+ const int64_t ldk = is_transposed ? D : Hkv * D;
210+ const int64_t ldv = is_transposed ? D : Hkv * D;
211+ // Output is always [B, S, Hq, D] so ldo = Hq * D = hidden_size
212+ const int64_t ldo = hidden_size;
213+
214+ torch::executor::parallel_for (
215+ 0 , batch * Hq, 1 , [&](int64_t begin, int64_t end) {
216+ for (int64_t idx = begin; idx < end; ++idx) {
217+ const int64_t b = idx / Hq;
218+ const int64_t h = idx % Hq;
219+ const int64_t kv_h = h / heads_per_group;
220+
221+ const float * q_ptr;
222+ const float * k_ptr;
223+ const float * v_ptr;
224+ if (is_transposed) {
225+ q_ptr = q_data + (b * Hq + h) * q_seq_len * D;
226+ k_ptr = k_data + (b * Hkv + kv_h) * max_seq_len * D;
227+ v_ptr = v_data + (b * Hkv + kv_h) * max_seq_len * D;
228+ } else {
229+ q_ptr = q_data + b * q_seq_len * Hq * D + h * D;
230+ k_ptr = k_data + b * max_seq_len * Hkv * D + kv_h * D;
231+ v_ptr = v_data + b * max_seq_len * Hkv * D + kv_h * D;
232+ }
233+ // Output always [B, S, Hq, D]: head h writes at stride hidden_size
234+ float * out_ptr = out_data + b * q_seq_len * hidden_size + h * D;
235+
236+ // Scores padded to max_seq_len columns (ONNX convention)
237+ float * scores = scores_buf + idx * q_seq_len * max_seq_len;
238+
239+ // GEMM 1: Q @ K^T with scale in alpha
240+ executorch::cpublas::gemm (
241+ TransposeType::Transpose,
242+ TransposeType::NoTranspose,
243+ total_seqlen,
244+ q_seq_len,
245+ D,
246+ alpha,
247+ k_ptr,
248+ ldk,
249+ q_ptr,
250+ ldq,
251+ 0 .0f ,
252+ scores,
253+ max_seq_len);
254+
255+ // Causal mask + narrow softmax (ONNX style):
256+ // Zero future positions, softmax only on valid [0, causal_len).
257+ for (int64_t qi = 0 ; qi < q_seq_len; ++qi) {
258+ float * row = scores + qi * max_seq_len;
259+ const int64_t causal_len =
260+ std::min (start_pos + qi + 1 , total_seqlen);
261+
262+ for (int64_t j = causal_len; j < total_seqlen; ++j) {
263+ row[j] = 0 .0f ;
264+ }
265+
266+ float max_val = row[0 ];
267+ for (int64_t j = 1 ; j < causal_len; ++j) {
268+ max_val = std::max (max_val, row[j]);
269+ }
270+ float sum = 0 .0f ;
271+ for (int64_t j = 0 ; j < causal_len; ++j) {
272+ row[j] = std::exp (row[j] - max_val);
273+ sum += row[j];
274+ }
275+ const float inv_sum = 1 .0f / sum;
276+ for (int64_t j = 0 ; j < causal_len; ++j) {
277+ row[j] *= inv_sum;
278+ }
279+ }
280+
281+ // GEMM 2: scores @ V -> output
282+ executorch::cpublas::gemm (
283+ TransposeType::NoTranspose,
284+ TransposeType::NoTranspose,
285+ D,
286+ q_seq_len,
287+ total_seqlen,
288+ 1 .0f ,
289+ v_ptr,
290+ ldv,
291+ scores,
292+ max_seq_len,
293+ 0 .0f ,
294+ out_ptr,
295+ ldo);
296+ }
297+ });
298+ }
299+
175300// Return max |a - b| across all elements.
176- float max_abs_diff (const Tensor& a, const Tensor& b) {
177- const float * a_data = a.const_data_ptr <float >();
178- const float * b_data = b.const_data_ptr <float >();
301+ float max_abs_diff (const float * a, const float * b, int64_t n) {
179302 float d = 0 .0f ;
180- for (int64_t i = 0 ; i < a. numel () ; ++i) {
181- d = std::max (d, std::abs (a_data [i] - b_data [i]));
303+ for (int64_t i = 0 ; i < n ; ++i) {
304+ d = std::max (d, std::abs (a [i] - b [i]));
182305 }
183306 return d;
184307}
185308
309+ float max_abs_diff (const Tensor& a, const Tensor& b) {
310+ return max_abs_diff (
311+ a.const_data_ptr <float >(), b.const_data_ptr <float >(), a.numel ());
312+ }
313+
186314// Validate a single config: run StandardSDPA and custom_sdpa_out on the same
187315// inputs, check outputs match within tolerance. Returns false on mismatch.
188316// Only tests standard [B,S,H,D] layout (is_transposed=false).
@@ -268,6 +396,65 @@ bool validate_config(
268396 (long )q_seq_len,
269397 diff);
270398
399+ // Also validate ONNX GQA variant. Output is always [B, S, Hq, D].
400+ // Since we only test standard [B,S,H,D] layout, out_ref is already
401+ // [B,S,Hq,D] — just copy directly to ref_bshd (no transpose needed).
402+ Tensor out_onnx =
403+ tf.zeros ({(int32_t )batch, (int32_t )q_seq_len, (int32_t )Hq, (int32_t )D});
404+ std::vector<float > onnx_scores_buf (batch * Hq * q_seq_len * max_seq_len);
405+ run_onnx_gqa_sdpa (
406+ q.const_data_ptr <float >(),
407+ k.const_data_ptr <float >(),
408+ v.const_data_ptr <float >(),
409+ out_onnx.mutable_data_ptr <float >(),
410+ onnx_scores_buf.data (),
411+ batch,
412+ Hq,
413+ Hkv,
414+ D,
415+ max_seq_len,
416+ start_pos,
417+ q_seq_len,
418+ false /* is_transposed */ );
419+
420+ // out_ref is already [B, S, Hq, D] (standard layout), compare directly
421+ std::vector<float > ref_bshd (batch * q_seq_len * Hq * D);
422+ const float * ref_ptr = out_ref.const_data_ptr <float >();
423+ std::copy (ref_ptr, ref_ptr + batch * q_seq_len * Hq * D, ref_bshd.data ());
424+
425+ float onnx_diff = max_abs_diff (
426+ out_onnx.const_data_ptr <float >(),
427+ ref_bshd.data (),
428+ batch * q_seq_len * Hq * D);
429+ if (onnx_diff > atol) {
430+ fprintf (
431+ stderr,
432+ " FAIL: OnnxGQA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
433+ " max_abs_diff=%.6e > atol=%.6e\n " ,
434+ mode,
435+ (long )batch,
436+ (long )Hq,
437+ (long )Hkv,
438+ (long )D,
439+ (long )start_pos,
440+ (long )q_seq_len,
441+ onnx_diff,
442+ atol);
443+ return false ;
444+ }
445+ fprintf (
446+ stderr,
447+ " PASS: OnnxGQA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
448+ " max_abs_diff=%.6e\n " ,
449+ mode,
450+ (long )batch,
451+ (long )Hq,
452+ (long )Hkv,
453+ (long )D,
454+ (long )start_pos,
455+ (long )q_seq_len,
456+ onnx_diff);
457+
271458 return true ;
272459}
273460
@@ -517,6 +704,132 @@ BENCHMARK_DEFINE_F(StandardSDPABenchFixture, StandardSDPA)
517704 }
518705}
519706
707+ // ONNX Runtime GQA-style benchmark. Faithfully matches the algorithm from
708+ // gqa_attention_base.h: scale-in-alpha, padded scores buffer, narrow softmax,
709+ // and output in [B, S, Hq, D] with stride Hq*D.
710+ class OnnxGQABenchFixture : public benchmark ::Fixture {
711+ public:
712+ // Args: {batch, num_heads_q, num_heads_kv, head_dim, max_seq_len, start_pos,
713+ // query_seq_len, is_transposed}
714+ void SetUp (benchmark::State& state) override {
715+ int64_t batch = state.range (0 );
716+ int64_t num_heads_q = state.range (1 );
717+ int64_t num_heads_kv = state.range (2 );
718+ int64_t head_dim = state.range (3 );
719+ int64_t max_seq_len = state.range (4 );
720+ int64_t start_pos = state.range (5 );
721+ int64_t q_seq_len = state.range (6 );
722+ bool is_transposed = state.range (7 ) != 0 ;
723+
724+ std::mt19937 gen (42 );
725+
726+ if (is_transposed) {
727+ q_.emplace (tf_.zeros (
728+ {(int32_t )batch,
729+ (int32_t )num_heads_q,
730+ (int32_t )q_seq_len,
731+ (int32_t )head_dim}));
732+ k_cache_.emplace (tf_.zeros (
733+ {(int32_t )batch,
734+ (int32_t )num_heads_kv,
735+ (int32_t )max_seq_len,
736+ (int32_t )head_dim}));
737+ v_cache_.emplace (tf_.zeros (
738+ {(int32_t )batch,
739+ (int32_t )num_heads_kv,
740+ (int32_t )max_seq_len,
741+ (int32_t )head_dim}));
742+ } else {
743+ q_.emplace (tf_.zeros (
744+ {(int32_t )batch,
745+ (int32_t )q_seq_len,
746+ (int32_t )num_heads_q,
747+ (int32_t )head_dim}));
748+ k_cache_.emplace (tf_.zeros (
749+ {(int32_t )batch,
750+ (int32_t )max_seq_len,
751+ (int32_t )num_heads_kv,
752+ (int32_t )head_dim}));
753+ v_cache_.emplace (tf_.zeros (
754+ {(int32_t )batch,
755+ (int32_t )max_seq_len,
756+ (int32_t )num_heads_kv,
757+ (int32_t )head_dim}));
758+ }
759+ // Output always [B, S, Hq, D] (ONNX convention)
760+ output_.emplace (tf_.zeros (
761+ {(int32_t )batch,
762+ (int32_t )q_seq_len,
763+ (int32_t )num_heads_q,
764+ (int32_t )head_dim}));
765+
766+ fill_random (*q_, gen);
767+ fill_random (*k_cache_, gen);
768+ fill_random (*v_cache_, gen);
769+
770+ batch_ = batch;
771+ num_heads_q_ = num_heads_q;
772+ num_heads_kv_ = num_heads_kv;
773+ head_dim_ = head_dim;
774+ max_seq_len_ = max_seq_len;
775+ start_pos_ = start_pos;
776+ q_seq_len_ = q_seq_len;
777+ is_transposed_ = is_transposed;
778+
779+ // Scores buffer padded to max_seq_len columns (ONNX convention)
780+ int64_t total_units = batch * num_heads_q;
781+ scores_buf_.resize (total_units * q_seq_len * max_seq_len);
782+ }
783+
784+ void TearDown (benchmark::State&) override {
785+ q_.reset ();
786+ k_cache_.reset ();
787+ v_cache_.reset ();
788+ output_.reset ();
789+ scores_buf_.clear ();
790+ }
791+
792+ TensorFactory<ScalarType::Float> tf_;
793+ std::optional<Tensor> q_;
794+ std::optional<Tensor> k_cache_;
795+ std::optional<Tensor> v_cache_;
796+ std::optional<Tensor> output_;
797+ std::vector<float > scores_buf_;
798+ int64_t batch_ = 0 ;
799+ int64_t num_heads_q_ = 0 ;
800+ int64_t num_heads_kv_ = 0 ;
801+ int64_t head_dim_ = 0 ;
802+ int64_t max_seq_len_ = 0 ;
803+ int64_t start_pos_ = 0 ;
804+ int64_t q_seq_len_ = 0 ;
805+ bool is_transposed_ = false ;
806+ };
807+
808+ BENCHMARK_DEFINE_F (OnnxGQABenchFixture, OnnxGQA)
809+ (benchmark::State& state) {
810+ const float * q_data = q_->const_data_ptr <float >();
811+ const float * k_data = k_cache_->const_data_ptr <float >();
812+ const float * v_data = v_cache_->const_data_ptr <float >();
813+ float * out_data = output_->mutable_data_ptr <float >();
814+
815+ for (auto _ : state) {
816+ run_onnx_gqa_sdpa (
817+ q_data,
818+ k_data,
819+ v_data,
820+ out_data,
821+ scores_buf_.data (),
822+ batch_,
823+ num_heads_q_,
824+ num_heads_kv_,
825+ head_dim_,
826+ max_seq_len_,
827+ start_pos_,
828+ q_seq_len_,
829+ is_transposed_);
830+ }
831+ }
832+
520833/*
521834 * Benchmark configurations modeled after Llama 3 8B (GQA: 32 q heads, 8 kv
522835 * heads, head_dim=128). We test decode (seq_len=1) and prefill scenarios at
@@ -565,6 +878,33 @@ BENCHMARK_REGISTER_F(StandardSDPABenchFixture, StandardSDPA)
565878 ->Args({1 , 32 , 32 , 128 , 2048 , 256 , 1 , 1 })
566879 ->ArgNames({" B" , " Hq" , " Hkv" , " D" , " MaxS" , " StartPos" , " SeqLen" , " Trans" });
567880
881+ // --- ONNX Runtime GQA-style SDPA ---
882+ // Same configs as StandardSDPA. Differences: scale-in-alpha, padded scores
883+ // buffer (ld=MaxS), narrow softmax, output in [B,S,Hq,D] with stride Hq*D.
884+ BENCHMARK_REGISTER_F (OnnxGQABenchFixture, OnnxGQA)
885+ // Standard layout decode at various cache positions
886+ ->Args ({1 , 32 , 8 , 128 , 2048 , 0 , 1 , 0 })
887+ ->Args({1 , 32 , 8 , 128 , 2048 , 64 , 1 , 0 })
888+ ->Args({1 , 32 , 8 , 128 , 2048 , 256 , 1 , 0 })
889+ ->Args({1 , 32 , 8 , 128 , 2048 , 512 , 1 , 0 })
890+ ->Args({1 , 32 , 8 , 128 , 2048 , 1024 , 1 , 0 })
891+ // Transposed layout decode at same positions
892+ ->Args({1 , 32 , 8 , 128 , 2048 , 0 , 1 , 1 })
893+ ->Args({1 , 32 , 8 , 128 , 2048 , 64 , 1 , 1 })
894+ ->Args({1 , 32 , 8 , 128 , 2048 , 256 , 1 , 1 })
895+ ->Args({1 , 32 , 8 , 128 , 2048 , 512 , 1 , 1 })
896+ ->Args({1 , 32 , 8 , 128 , 2048 , 1024 , 1 , 1 })
897+ // Standard layout prefill
898+ ->Args({1 , 32 , 8 , 128 , 2048 , 0 , 128 , 0 })
899+ ->Args({1 , 32 , 8 , 128 , 2048 , 0 , 512 , 0 })
900+ // Transposed layout prefill
901+ ->Args({1 , 32 , 8 , 128 , 2048 , 0 , 128 , 1 })
902+ ->Args({1 , 32 , 8 , 128 , 2048 , 0 , 512 , 1 })
903+ // Llama 2 style (32 heads, no GQA)
904+ ->Args({1 , 32 , 32 , 128 , 2048 , 256 , 1 , 0 })
905+ ->Args({1 , 32 , 32 , 128 , 2048 , 256 , 1 , 1 })
906+ ->ArgNames({" B" , " Hq" , " Hkv" , " D" , " MaxS" , " StartPos" , " SeqLen" , " Trans" });
907+
568908} // namespace
569909
570910int main (int argc, char ** argv) {
0 commit comments