Skip to content

Commit 5cc3cde

Browse files
committed
Add ONNX Runtime GQA-style SDPA benchmark
Pull Request resolved: #18647 Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. ghstack-source-id: 374666318 @exported-using-ghexport Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/)
1 parent 48773c2 commit 5cc3cde

1 file changed

Lines changed: 345 additions & 5 deletions

File tree

extension/llm/custom_ops/bench_sdpa.cpp

Lines changed: 345 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,17 +172,145 @@ void run_standard_sdpa(
172172
});
173173
}
174174

175+
// ONNX Runtime GQA-style SDPA, faithfully ported from
176+
// onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h.
177+
// Differences from run_standard_sdpa:
178+
// 1. Scale in GEMM alpha (no separate scaling pass)
179+
// 2. Scores buffer padded to max_seq_len cols (ONNX's present_buffer_seq_len)
180+
// 3. Causal mask: zero out future positions, softmax on valid window only
181+
// 4. Output in [B, S, Hq, D] with stride Hq*D (ONNX's interleaved BNSH->BSNH)
182+
//
183+
// When is_transposed=true, inputs are [B,H,S,D]; output is [B,S,Hq,D].
184+
// When is_transposed=false, inputs are [B,S,H,D]; output is [B,S,Hq,D].
185+
// Output is always [B, S, Hq, D] to match ONNX's actual output format.
186+
void run_onnx_gqa_sdpa(
187+
const float* q_data,
188+
const float* k_data,
189+
const float* v_data,
190+
float* out_data, // always [B, q_seq_len, Hq, D]
191+
float* scores_buf, // must hold batch*Hq*q_seq_len*max_seq_len floats
192+
int64_t batch,
193+
int64_t Hq,
194+
int64_t Hkv,
195+
int64_t D,
196+
int64_t max_seq_len,
197+
int64_t start_pos,
198+
int64_t q_seq_len,
199+
bool is_transposed) {
200+
using executorch::cpublas::TransposeType;
201+
202+
const int64_t total_seqlen = start_pos + q_seq_len;
203+
const float alpha = 1.0f / std::sqrt(static_cast<float>(D));
204+
const int64_t heads_per_group = Hq / Hkv;
205+
const int64_t hidden_size = Hq * D; // output row stride (ONNX convention)
206+
207+
// Input strides depend on layout
208+
const int64_t ldq = is_transposed ? D : Hq * D;
209+
const int64_t ldk = is_transposed ? D : Hkv * D;
210+
const int64_t ldv = is_transposed ? D : Hkv * D;
211+
// Output is always [B, S, Hq, D] so ldo = Hq * D = hidden_size
212+
const int64_t ldo = hidden_size;
213+
214+
torch::executor::parallel_for(
215+
0, batch * Hq, 1, [&](int64_t begin, int64_t end) {
216+
for (int64_t idx = begin; idx < end; ++idx) {
217+
const int64_t b = idx / Hq;
218+
const int64_t h = idx % Hq;
219+
const int64_t kv_h = h / heads_per_group;
220+
221+
const float* q_ptr;
222+
const float* k_ptr;
223+
const float* v_ptr;
224+
if (is_transposed) {
225+
q_ptr = q_data + (b * Hq + h) * q_seq_len * D;
226+
k_ptr = k_data + (b * Hkv + kv_h) * max_seq_len * D;
227+
v_ptr = v_data + (b * Hkv + kv_h) * max_seq_len * D;
228+
} else {
229+
q_ptr = q_data + b * q_seq_len * Hq * D + h * D;
230+
k_ptr = k_data + b * max_seq_len * Hkv * D + kv_h * D;
231+
v_ptr = v_data + b * max_seq_len * Hkv * D + kv_h * D;
232+
}
233+
// Output always [B, S, Hq, D]: head h writes at stride hidden_size
234+
float* out_ptr = out_data + b * q_seq_len * hidden_size + h * D;
235+
236+
// Scores padded to max_seq_len columns (ONNX convention)
237+
float* scores = scores_buf + idx * q_seq_len * max_seq_len;
238+
239+
// GEMM 1: Q @ K^T with scale in alpha
240+
executorch::cpublas::gemm(
241+
TransposeType::Transpose,
242+
TransposeType::NoTranspose,
243+
total_seqlen,
244+
q_seq_len,
245+
D,
246+
alpha,
247+
k_ptr,
248+
ldk,
249+
q_ptr,
250+
ldq,
251+
0.0f,
252+
scores,
253+
max_seq_len);
254+
255+
// Causal mask + narrow softmax (ONNX style):
256+
// Zero future positions, softmax only on valid [0, causal_len).
257+
for (int64_t qi = 0; qi < q_seq_len; ++qi) {
258+
float* row = scores + qi * max_seq_len;
259+
const int64_t causal_len =
260+
std::min(start_pos + qi + 1, total_seqlen);
261+
262+
for (int64_t j = causal_len; j < total_seqlen; ++j) {
263+
row[j] = 0.0f;
264+
}
265+
266+
float max_val = row[0];
267+
for (int64_t j = 1; j < causal_len; ++j) {
268+
max_val = std::max(max_val, row[j]);
269+
}
270+
float sum = 0.0f;
271+
for (int64_t j = 0; j < causal_len; ++j) {
272+
row[j] = std::exp(row[j] - max_val);
273+
sum += row[j];
274+
}
275+
const float inv_sum = 1.0f / sum;
276+
for (int64_t j = 0; j < causal_len; ++j) {
277+
row[j] *= inv_sum;
278+
}
279+
}
280+
281+
// GEMM 2: scores @ V -> output
282+
executorch::cpublas::gemm(
283+
TransposeType::NoTranspose,
284+
TransposeType::NoTranspose,
285+
D,
286+
q_seq_len,
287+
total_seqlen,
288+
1.0f,
289+
v_ptr,
290+
ldv,
291+
scores,
292+
max_seq_len,
293+
0.0f,
294+
out_ptr,
295+
ldo);
296+
}
297+
});
298+
}
299+
175300
// Return max |a - b| across all elements.
176-
float max_abs_diff(const Tensor& a, const Tensor& b) {
177-
const float* a_data = a.const_data_ptr<float>();
178-
const float* b_data = b.const_data_ptr<float>();
301+
float max_abs_diff(const float* a, const float* b, int64_t n) {
179302
float d = 0.0f;
180-
for (int64_t i = 0; i < a.numel(); ++i) {
181-
d = std::max(d, std::abs(a_data[i] - b_data[i]));
303+
for (int64_t i = 0; i < n; ++i) {
304+
d = std::max(d, std::abs(a[i] - b[i]));
182305
}
183306
return d;
184307
}
185308

309+
float max_abs_diff(const Tensor& a, const Tensor& b) {
310+
return max_abs_diff(
311+
a.const_data_ptr<float>(), b.const_data_ptr<float>(), a.numel());
312+
}
313+
186314
// Validate a single config: run StandardSDPA and custom_sdpa_out on the same
187315
// inputs, check outputs match within tolerance. Returns false on mismatch.
188316
// Only tests standard [B,S,H,D] layout (is_transposed=false).
@@ -268,6 +396,65 @@ bool validate_config(
268396
(long)q_seq_len,
269397
diff);
270398

399+
// Also validate ONNX GQA variant. Output is always [B, S, Hq, D].
400+
// Since we only test standard [B,S,H,D] layout, out_ref is already
401+
// [B,S,Hq,D] — just copy directly to ref_bshd (no transpose needed).
402+
Tensor out_onnx =
403+
tf.zeros({(int32_t)batch, (int32_t)q_seq_len, (int32_t)Hq, (int32_t)D});
404+
std::vector<float> onnx_scores_buf(batch * Hq * q_seq_len * max_seq_len);
405+
run_onnx_gqa_sdpa(
406+
q.const_data_ptr<float>(),
407+
k.const_data_ptr<float>(),
408+
v.const_data_ptr<float>(),
409+
out_onnx.mutable_data_ptr<float>(),
410+
onnx_scores_buf.data(),
411+
batch,
412+
Hq,
413+
Hkv,
414+
D,
415+
max_seq_len,
416+
start_pos,
417+
q_seq_len,
418+
false /* is_transposed */);
419+
420+
// out_ref is already [B, S, Hq, D] (standard layout), compare directly
421+
std::vector<float> ref_bshd(batch * q_seq_len * Hq * D);
422+
const float* ref_ptr = out_ref.const_data_ptr<float>();
423+
std::copy(ref_ptr, ref_ptr + batch * q_seq_len * Hq * D, ref_bshd.data());
424+
425+
float onnx_diff = max_abs_diff(
426+
out_onnx.const_data_ptr<float>(),
427+
ref_bshd.data(),
428+
batch * q_seq_len * Hq * D);
429+
if (onnx_diff > atol) {
430+
fprintf(
431+
stderr,
432+
"FAIL: OnnxGQA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
433+
"max_abs_diff=%.6e > atol=%.6e\n",
434+
mode,
435+
(long)batch,
436+
(long)Hq,
437+
(long)Hkv,
438+
(long)D,
439+
(long)start_pos,
440+
(long)q_seq_len,
441+
onnx_diff,
442+
atol);
443+
return false;
444+
}
445+
fprintf(
446+
stderr,
447+
"PASS: OnnxGQA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
448+
"max_abs_diff=%.6e\n",
449+
mode,
450+
(long)batch,
451+
(long)Hq,
452+
(long)Hkv,
453+
(long)D,
454+
(long)start_pos,
455+
(long)q_seq_len,
456+
onnx_diff);
457+
271458
return true;
272459
}
273460

@@ -517,6 +704,132 @@ BENCHMARK_DEFINE_F(StandardSDPABenchFixture, StandardSDPA)
517704
}
518705
}
519706

707+
// ONNX Runtime GQA-style benchmark. Faithfully matches the algorithm from
708+
// gqa_attention_base.h: scale-in-alpha, padded scores buffer, narrow softmax,
709+
// and output in [B, S, Hq, D] with stride Hq*D.
710+
class OnnxGQABenchFixture : public benchmark::Fixture {
711+
public:
712+
// Args: {batch, num_heads_q, num_heads_kv, head_dim, max_seq_len, start_pos,
713+
// query_seq_len, is_transposed}
714+
void SetUp(benchmark::State& state) override {
715+
int64_t batch = state.range(0);
716+
int64_t num_heads_q = state.range(1);
717+
int64_t num_heads_kv = state.range(2);
718+
int64_t head_dim = state.range(3);
719+
int64_t max_seq_len = state.range(4);
720+
int64_t start_pos = state.range(5);
721+
int64_t q_seq_len = state.range(6);
722+
bool is_transposed = state.range(7) != 0;
723+
724+
std::mt19937 gen(42);
725+
726+
if (is_transposed) {
727+
q_.emplace(tf_.zeros(
728+
{(int32_t)batch,
729+
(int32_t)num_heads_q,
730+
(int32_t)q_seq_len,
731+
(int32_t)head_dim}));
732+
k_cache_.emplace(tf_.zeros(
733+
{(int32_t)batch,
734+
(int32_t)num_heads_kv,
735+
(int32_t)max_seq_len,
736+
(int32_t)head_dim}));
737+
v_cache_.emplace(tf_.zeros(
738+
{(int32_t)batch,
739+
(int32_t)num_heads_kv,
740+
(int32_t)max_seq_len,
741+
(int32_t)head_dim}));
742+
} else {
743+
q_.emplace(tf_.zeros(
744+
{(int32_t)batch,
745+
(int32_t)q_seq_len,
746+
(int32_t)num_heads_q,
747+
(int32_t)head_dim}));
748+
k_cache_.emplace(tf_.zeros(
749+
{(int32_t)batch,
750+
(int32_t)max_seq_len,
751+
(int32_t)num_heads_kv,
752+
(int32_t)head_dim}));
753+
v_cache_.emplace(tf_.zeros(
754+
{(int32_t)batch,
755+
(int32_t)max_seq_len,
756+
(int32_t)num_heads_kv,
757+
(int32_t)head_dim}));
758+
}
759+
// Output always [B, S, Hq, D] (ONNX convention)
760+
output_.emplace(tf_.zeros(
761+
{(int32_t)batch,
762+
(int32_t)q_seq_len,
763+
(int32_t)num_heads_q,
764+
(int32_t)head_dim}));
765+
766+
fill_random(*q_, gen);
767+
fill_random(*k_cache_, gen);
768+
fill_random(*v_cache_, gen);
769+
770+
batch_ = batch;
771+
num_heads_q_ = num_heads_q;
772+
num_heads_kv_ = num_heads_kv;
773+
head_dim_ = head_dim;
774+
max_seq_len_ = max_seq_len;
775+
start_pos_ = start_pos;
776+
q_seq_len_ = q_seq_len;
777+
is_transposed_ = is_transposed;
778+
779+
// Scores buffer padded to max_seq_len columns (ONNX convention)
780+
int64_t total_units = batch * num_heads_q;
781+
scores_buf_.resize(total_units * q_seq_len * max_seq_len);
782+
}
783+
784+
void TearDown(benchmark::State&) override {
785+
q_.reset();
786+
k_cache_.reset();
787+
v_cache_.reset();
788+
output_.reset();
789+
scores_buf_.clear();
790+
}
791+
792+
TensorFactory<ScalarType::Float> tf_;
793+
std::optional<Tensor> q_;
794+
std::optional<Tensor> k_cache_;
795+
std::optional<Tensor> v_cache_;
796+
std::optional<Tensor> output_;
797+
std::vector<float> scores_buf_;
798+
int64_t batch_ = 0;
799+
int64_t num_heads_q_ = 0;
800+
int64_t num_heads_kv_ = 0;
801+
int64_t head_dim_ = 0;
802+
int64_t max_seq_len_ = 0;
803+
int64_t start_pos_ = 0;
804+
int64_t q_seq_len_ = 0;
805+
bool is_transposed_ = false;
806+
};
807+
808+
BENCHMARK_DEFINE_F(OnnxGQABenchFixture, OnnxGQA)
809+
(benchmark::State& state) {
810+
const float* q_data = q_->const_data_ptr<float>();
811+
const float* k_data = k_cache_->const_data_ptr<float>();
812+
const float* v_data = v_cache_->const_data_ptr<float>();
813+
float* out_data = output_->mutable_data_ptr<float>();
814+
815+
for (auto _ : state) {
816+
run_onnx_gqa_sdpa(
817+
q_data,
818+
k_data,
819+
v_data,
820+
out_data,
821+
scores_buf_.data(),
822+
batch_,
823+
num_heads_q_,
824+
num_heads_kv_,
825+
head_dim_,
826+
max_seq_len_,
827+
start_pos_,
828+
q_seq_len_,
829+
is_transposed_);
830+
}
831+
}
832+
520833
/*
521834
* Benchmark configurations modeled after Llama 3 8B (GQA: 32 q heads, 8 kv
522835
* heads, head_dim=128). We test decode (seq_len=1) and prefill scenarios at
@@ -565,6 +878,33 @@ BENCHMARK_REGISTER_F(StandardSDPABenchFixture, StandardSDPA)
565878
->Args({1, 32, 32, 128, 2048, 256, 1, 1})
566879
->ArgNames({"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"});
567880

881+
// --- ONNX Runtime GQA-style SDPA ---
882+
// Same configs as StandardSDPA. Differences: scale-in-alpha, padded scores
883+
// buffer (ld=MaxS), narrow softmax, output in [B,S,Hq,D] with stride Hq*D.
884+
BENCHMARK_REGISTER_F(OnnxGQABenchFixture, OnnxGQA)
885+
// Standard layout decode at various cache positions
886+
->Args({1, 32, 8, 128, 2048, 0, 1, 0})
887+
->Args({1, 32, 8, 128, 2048, 64, 1, 0})
888+
->Args({1, 32, 8, 128, 2048, 256, 1, 0})
889+
->Args({1, 32, 8, 128, 2048, 512, 1, 0})
890+
->Args({1, 32, 8, 128, 2048, 1024, 1, 0})
891+
// Transposed layout decode at same positions
892+
->Args({1, 32, 8, 128, 2048, 0, 1, 1})
893+
->Args({1, 32, 8, 128, 2048, 64, 1, 1})
894+
->Args({1, 32, 8, 128, 2048, 256, 1, 1})
895+
->Args({1, 32, 8, 128, 2048, 512, 1, 1})
896+
->Args({1, 32, 8, 128, 2048, 1024, 1, 1})
897+
// Standard layout prefill
898+
->Args({1, 32, 8, 128, 2048, 0, 128, 0})
899+
->Args({1, 32, 8, 128, 2048, 0, 512, 0})
900+
// Transposed layout prefill
901+
->Args({1, 32, 8, 128, 2048, 0, 128, 1})
902+
->Args({1, 32, 8, 128, 2048, 0, 512, 1})
903+
// Llama 2 style (32 heads, no GQA)
904+
->Args({1, 32, 32, 128, 2048, 256, 1, 0})
905+
->Args({1, 32, 32, 128, 2048, 256, 1, 1})
906+
->ArgNames({"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"});
907+
568908
} // namespace
569909

570910
int main(int argc, char** argv) {

0 commit comments

Comments
 (0)