@@ -147,17 +147,128 @@ void run_standard_sdpa(
147147 });
148148}
149149
150+ // ONNX Runtime GQA-style SDPA, faithfully ported from
151+ // onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h.
152+ // Differences from run_standard_sdpa:
153+ // 1. Scale in GEMM alpha (no separate scaling pass)
154+ // 2. Scores buffer padded to max_seq_len cols (ONNX's present_buffer_seq_len)
155+ // 3. Causal mask: zero out future positions, softmax on valid window only
156+ // 4. Output in [B, S, Hq, D] with stride Hq*D (ONNX's interleaved BNSH->BSNH)
157+ //
158+ // When is_transposed=true, inputs are [B,H,S,D]; output is [B,S,Hq,D].
159+ // When is_transposed=false, inputs are [B,S,H,D]; output is [B,S,Hq,D].
160+ // Output is always [B, S, Hq, D] to match ONNX's actual output format.
161+ void run_onnx_gqa_sdpa (
162+ const float * q_data,
163+ const float * k_data,
164+ const float * v_data,
165+ float * out_data, // always [B, q_seq_len, Hq, D]
166+ float * scores_buf, // must hold batch*Hq*q_seq_len*max_seq_len floats
167+ int64_t batch,
168+ int64_t Hq,
169+ int64_t Hkv,
170+ int64_t D,
171+ int64_t max_seq_len,
172+ int64_t start_pos,
173+ int64_t q_seq_len,
174+ bool is_transposed) {
175+ using executorch::cpublas::TransposeType;
176+
177+ const int64_t total_seqlen = start_pos + q_seq_len;
178+ const float alpha = 1 .0f / std::sqrt (static_cast <float >(D));
179+ const int64_t heads_per_group = Hq / Hkv;
180+ const int64_t hidden_size = Hq * D; // output row stride (ONNX convention)
181+
182+ // Input strides depend on layout
183+ const int64_t ldq = is_transposed ? D : Hq * D;
184+ const int64_t ldk = is_transposed ? D : Hkv * D;
185+ const int64_t ldv = is_transposed ? D : Hkv * D;
186+ // Output is always [B, S, Hq, D] so ldo = Hq * D = hidden_size
187+ const int64_t ldo = hidden_size;
188+
189+ torch::executor::parallel_for (
190+ 0 , batch * Hq, 1 , [&](int64_t begin, int64_t end) {
191+ for (int64_t idx = begin; idx < end; ++idx) {
192+ const int64_t b = idx / Hq;
193+ const int64_t h = idx % Hq;
194+ const int64_t kv_h = h / heads_per_group;
195+
196+ const float * q_ptr;
197+ const float * k_ptr;
198+ const float * v_ptr;
199+ if (is_transposed) {
200+ q_ptr = q_data + (b * Hq + h) * q_seq_len * D;
201+ k_ptr = k_data + (b * Hkv + kv_h) * max_seq_len * D;
202+ v_ptr = v_data + (b * Hkv + kv_h) * max_seq_len * D;
203+ } else {
204+ q_ptr = q_data + b * q_seq_len * Hq * D + h * D;
205+ k_ptr = k_data + b * max_seq_len * Hkv * D + kv_h * D;
206+ v_ptr = v_data + b * max_seq_len * Hkv * D + kv_h * D;
207+ }
208+ // Output always [B, S, Hq, D]: head h writes at stride hidden_size
209+ float * out_ptr =
210+ out_data + b * q_seq_len * hidden_size + h * D;
211+
212+ // Scores padded to max_seq_len columns (ONNX convention)
213+ float * scores = scores_buf + idx * q_seq_len * max_seq_len;
214+
215+ // GEMM 1: Q @ K^T with scale in alpha
216+ executorch::cpublas::gemm (
217+ TransposeType::Transpose, TransposeType::NoTranspose,
218+ total_seqlen, q_seq_len, D,
219+ alpha, k_ptr, ldk, q_ptr, ldq,
220+ 0 .0f , scores, max_seq_len);
221+
222+ // Causal mask + narrow softmax (ONNX style):
223+ // Zero future positions, softmax only on valid [0, causal_len).
224+ for (int64_t qi = 0 ; qi < q_seq_len; ++qi) {
225+ float * row = scores + qi * max_seq_len;
226+ const int64_t causal_len =
227+ std::min (start_pos + qi + 1 , total_seqlen);
228+
229+ for (int64_t j = causal_len; j < total_seqlen; ++j) {
230+ row[j] = 0 .0f ;
231+ }
232+
233+ float max_val = row[0 ];
234+ for (int64_t j = 1 ; j < causal_len; ++j) {
235+ max_val = std::max (max_val, row[j]);
236+ }
237+ float sum = 0 .0f ;
238+ for (int64_t j = 0 ; j < causal_len; ++j) {
239+ row[j] = std::exp (row[j] - max_val);
240+ sum += row[j];
241+ }
242+ const float inv_sum = 1 .0f / sum;
243+ for (int64_t j = 0 ; j < causal_len; ++j) {
244+ row[j] *= inv_sum;
245+ }
246+ }
247+
248+ // GEMM 2: scores @ V -> output
249+ executorch::cpublas::gemm (
250+ TransposeType::NoTranspose, TransposeType::NoTranspose,
251+ D, q_seq_len, total_seqlen,
252+ 1 .0f , v_ptr, ldv, scores, max_seq_len,
253+ 0 .0f , out_ptr, ldo);
254+ }
255+ });
256+ }
257+
// Largest element-wise absolute difference between two length-n buffers.
// Returns 0 for n == 0.
float max_abs_diff(const float* a, const float* b, int64_t n) {
  float worst = 0.0f;
  for (int64_t i = 0; i < n; ++i) {
    const float delta = std::abs(a[i] - b[i]);
    if (delta > worst) {
      worst = delta;
    }
  }
  return worst;
}
160266
267+ float max_abs_diff (const Tensor& a, const Tensor& b) {
268+ return max_abs_diff (
269+ a.const_data_ptr <float >(), b.const_data_ptr <float >(), a.numel ());
270+ }
271+
161272// Validate a single config: run StandardSDPA and custom_sdpa_out on the same
162273// inputs, check outputs match within tolerance. Returns false on mismatch.
163274// Only tests standard [B,S,H,D] layout (is_transposed=false).
@@ -226,6 +337,45 @@ bool validate_config(
226337 mode, (long )batch, (long )Hq, (long )Hkv, (long )D,
227338 (long )start_pos, (long )q_seq_len, diff);
228339
340+ // Also validate ONNX GQA variant. Output is always [B, S, Hq, D].
341+ // Since we only test standard [B,S,H,D] layout, out_ref is already
342+ // [B,S,Hq,D] — just copy directly to ref_bshd (no transpose needed).
343+ Tensor out_onnx =
344+ tf.zeros ({(int32_t )batch, (int32_t )q_seq_len, (int32_t )Hq, (int32_t )D});
345+ std::vector<float > onnx_scores_buf (batch * Hq * q_seq_len * max_seq_len);
346+ run_onnx_gqa_sdpa (
347+ q.const_data_ptr <float >(),
348+ k.const_data_ptr <float >(),
349+ v.const_data_ptr <float >(),
350+ out_onnx.mutable_data_ptr <float >(),
351+ onnx_scores_buf.data (),
352+ batch, Hq, Hkv, D, max_seq_len, start_pos, q_seq_len,
353+ false /* is_transposed */ );
354+
355+ // out_ref is already [B, S, Hq, D] (standard layout), compare directly
356+ std::vector<float > ref_bshd (batch * q_seq_len * Hq * D);
357+ const float * ref_ptr = out_ref.const_data_ptr <float >();
358+ std::copy (ref_ptr, ref_ptr + batch * q_seq_len * Hq * D, ref_bshd.data ());
359+
360+ float onnx_diff = max_abs_diff (
361+ out_onnx.const_data_ptr <float >(), ref_bshd.data (),
362+ batch * q_seq_len * Hq * D);
363+ if (onnx_diff > atol) {
364+ fprintf (
365+ stderr,
366+ " FAIL: OnnxGQA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
367+ " max_abs_diff=%.6e > atol=%.6e\n " ,
368+ mode, (long )batch, (long )Hq, (long )Hkv, (long )D,
369+ (long )start_pos, (long )q_seq_len, onnx_diff, atol);
370+ return false ;
371+ }
372+ fprintf (
373+ stderr,
374+ " PASS: OnnxGQA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
375+ " max_abs_diff=%.6e\n " ,
376+ mode, (long )batch, (long )Hq, (long )Hkv, (long )D,
377+ (long )start_pos, (long )q_seq_len, onnx_diff);
378+
229379 return true ;
230380}
231381
@@ -429,6 +579,101 @@ BENCHMARK_DEFINE_F(StandardSDPABenchFixture, StandardSDPA)
429579 }
430580}
431581
582+ // ONNX Runtime GQA-style benchmark. Faithfully matches the algorithm from
583+ // gqa_attention_base.h: scale-in-alpha, padded scores buffer, narrow softmax,
584+ // and output in [B, S, Hq, D] with stride Hq*D.
585+ class OnnxGQABenchFixture : public benchmark ::Fixture {
586+ public:
587+ // Args: {batch, num_heads_q, num_heads_kv, head_dim, max_seq_len, start_pos,
588+ // query_seq_len, is_transposed}
589+ void SetUp (benchmark::State& state) override {
590+ int64_t batch = state.range (0 );
591+ int64_t num_heads_q = state.range (1 );
592+ int64_t num_heads_kv = state.range (2 );
593+ int64_t head_dim = state.range (3 );
594+ int64_t max_seq_len = state.range (4 );
595+ int64_t start_pos = state.range (5 );
596+ int64_t q_seq_len = state.range (6 );
597+ bool is_transposed = state.range (7 ) != 0 ;
598+
599+ std::mt19937 gen (42 );
600+
601+ if (is_transposed) {
602+ q_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )num_heads_q,
603+ (int32_t )q_seq_len, (int32_t )head_dim}));
604+ k_cache_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )num_heads_kv,
605+ (int32_t )max_seq_len, (int32_t )head_dim}));
606+ v_cache_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )num_heads_kv,
607+ (int32_t )max_seq_len, (int32_t )head_dim}));
608+ } else {
609+ q_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )q_seq_len,
610+ (int32_t )num_heads_q, (int32_t )head_dim}));
611+ k_cache_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )max_seq_len,
612+ (int32_t )num_heads_kv, (int32_t )head_dim}));
613+ v_cache_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )max_seq_len,
614+ (int32_t )num_heads_kv, (int32_t )head_dim}));
615+ }
616+ // Output always [B, S, Hq, D] (ONNX convention)
617+ output_.emplace (tf_.zeros ({(int32_t )batch, (int32_t )q_seq_len,
618+ (int32_t )num_heads_q, (int32_t )head_dim}));
619+
620+ fill_random (*q_, gen);
621+ fill_random (*k_cache_, gen);
622+ fill_random (*v_cache_, gen);
623+
624+ batch_ = batch;
625+ num_heads_q_ = num_heads_q;
626+ num_heads_kv_ = num_heads_kv;
627+ head_dim_ = head_dim;
628+ max_seq_len_ = max_seq_len;
629+ start_pos_ = start_pos;
630+ q_seq_len_ = q_seq_len;
631+ is_transposed_ = is_transposed;
632+
633+ // Scores buffer padded to max_seq_len columns (ONNX convention)
634+ int64_t total_units = batch * num_heads_q;
635+ scores_buf_.resize (total_units * q_seq_len * max_seq_len);
636+ }
637+
638+ void TearDown (benchmark::State&) override {
639+ q_.reset ();
640+ k_cache_.reset ();
641+ v_cache_.reset ();
642+ output_.reset ();
643+ scores_buf_.clear ();
644+ }
645+
646+ TensorFactory<ScalarType::Float> tf_;
647+ std::optional<Tensor> q_;
648+ std::optional<Tensor> k_cache_;
649+ std::optional<Tensor> v_cache_;
650+ std::optional<Tensor> output_;
651+ std::vector<float > scores_buf_;
652+ int64_t batch_ = 0 ;
653+ int64_t num_heads_q_ = 0 ;
654+ int64_t num_heads_kv_ = 0 ;
655+ int64_t head_dim_ = 0 ;
656+ int64_t max_seq_len_ = 0 ;
657+ int64_t start_pos_ = 0 ;
658+ int64_t q_seq_len_ = 0 ;
659+ bool is_transposed_ = false ;
660+ };
661+
// Timed body for OnnxGQABenchFixture. Data pointer fetches are hoisted out
// of the state loop so each iteration measures run_onnx_gqa_sdpa alone.
BENCHMARK_DEFINE_F (OnnxGQABenchFixture, OnnxGQA)
(benchmark::State& state) {
  const float* q_data = q_->const_data_ptr<float>();
  const float* k_data = k_cache_->const_data_ptr<float>();
  const float* v_data = v_cache_->const_data_ptr<float>();
  float* out_data = output_->mutable_data_ptr<float>();

  // run_onnx_gqa_sdpa fully overwrites out_data (beta=0 in its output GEMM),
  // so no per-iteration reset of the buffers is needed.
  for (auto _ : state) {
    run_onnx_gqa_sdpa(
        q_data, k_data, v_data, out_data, scores_buf_.data(),
        batch_, num_heads_q_, num_heads_kv_, head_dim_,
        max_seq_len_, start_pos_, q_seq_len_, is_transposed_);
  }
}
676+
432677/*
433678 * Benchmark configurations modeled after Llama 3 8B (GQA: 32 q heads, 8 kv
434679 * heads, head_dim=128). We test decode (seq_len=1) and prefill scenarios at
@@ -478,6 +723,34 @@ BENCHMARK_REGISTER_F(StandardSDPABenchFixture, StandardSDPA)
478723 ->ArgNames(
479724 {" B" , " Hq" , " Hkv" , " D" , " MaxS" , " StartPos" , " SeqLen" , " Trans" });
480725
// --- ONNX Runtime GQA-style SDPA ---
// Same configs as StandardSDPA. Differences: scale-in-alpha, padded scores
// buffer (ld=MaxS), narrow softmax, output in [B,S,Hq,D] with stride Hq*D.
// Each Args tuple follows the ArgNames order declared at the bottom:
// {B, Hq, Hkv, D, MaxS, StartPos, SeqLen, Trans}.
BENCHMARK_REGISTER_F(OnnxGQABenchFixture, OnnxGQA)
    // Standard layout decode (SeqLen=1) at various cache positions
    ->Args({1, 32, 8, 128, 2048, 0, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 64, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 256, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 512, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 1024, 1, 0})
    // Transposed layout decode at same positions
    ->Args({1, 32, 8, 128, 2048, 0, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 64, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 256, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 512, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 1024, 1, 1})
    // Standard layout prefill
    ->Args({1, 32, 8, 128, 2048, 0, 128, 0})
    ->Args({1, 32, 8, 128, 2048, 0, 512, 0})
    // Transposed layout prefill
    ->Args({1, 32, 8, 128, 2048, 0, 128, 1})
    ->Args({1, 32, 8, 128, 2048, 0, 512, 1})
    // Llama 2 style (32 heads, no GQA: Hq == Hkv)
    ->Args({1, 32, 32, 128, 2048, 256, 1, 0})
    ->Args({1, 32, 32, 128, 2048, 256, 1, 1})
    ->ArgNames(
        {"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"});
753+
481754} // namespace
482755
483756int main (int argc, char ** argv) {
0 commit comments