Skip to content

Commit d142f79

Browse files
committed
Add ONNX Runtime GQA-style SDPA benchmark
Add run_onnx_gqa_sdpa(), which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h:

- Scale baked into the GEMM alpha (no separate scaling pass)
- Scores buffer padded to max_seq_len columns
- Causal mask: zero out future positions; softmax over the valid window only
- Output always in [B, S, Hq, D] format

Extends validation to verify that the ONNX GQA output matches the custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts.

Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) ghstack-source-id: 361224786 Pull Request resolved: #18647
1 parent 4a828d4 commit d142f79

File tree

1 file changed

+278
-5
lines changed

1 file changed

+278
-5
lines changed

extension/llm/custom_ops/bench_sdpa.cpp

Lines changed: 278 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,17 +147,128 @@ void run_standard_sdpa(
147147
});
148148
}
149149

150+
// ONNX Runtime GQA-style SDPA, faithfully ported from
// onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h.
// Differences from run_standard_sdpa:
// 1. Scale in GEMM alpha (no separate scaling pass)
// 2. Scores buffer padded to max_seq_len cols (ONNX's present_buffer_seq_len)
// 3. Causal mask: zero out future positions, softmax on valid window only
// 4. Output in [B, S, Hq, D] with stride Hq*D (ONNX's interleaved BNSH->BSNH)
//
// When is_transposed=true, inputs are [B,H,S,D]; output is [B,S,Hq,D].
// When is_transposed=false, inputs are [B,S,H,D]; output is [B,S,Hq,D].
// Output is always [B, S, Hq, D] to match ONNX's actual output format.
//
// Parameters:
//   q_data/k_data/v_data: query / key-cache / value-cache buffers in the
//       layout selected by is_transposed (K/V caches sized to max_seq_len).
//   out_data:   destination, always [B, q_seq_len, Hq, D].
//   scores_buf: scratch; must hold batch*Hq*q_seq_len*max_seq_len floats.
//   batch, Hq, Hkv, D: batch size, query heads, KV heads, head dim.
//       Hq must be a multiple of Hkv (grouped-query attention).
//   max_seq_len: KV-cache capacity (and scores row stride).
//   start_pos:   number of cached positions preceding this query chunk.
//   q_seq_len:   number of query positions in this call.
void run_onnx_gqa_sdpa(
    const float* q_data,
    const float* k_data,
    const float* v_data,
    float* out_data, // always [B, q_seq_len, Hq, D]
    float* scores_buf, // must hold batch*Hq*q_seq_len*max_seq_len floats
    int64_t batch,
    int64_t Hq,
    int64_t Hkv,
    int64_t D,
    int64_t max_seq_len,
    int64_t start_pos,
    int64_t q_seq_len,
    bool is_transposed) {
  using executorch::cpublas::TransposeType;

  // Attention window: cached positions plus the new query positions.
  const int64_t total_seqlen = start_pos + q_seq_len;
  const float alpha = 1.0f / std::sqrt(static_cast<float>(D));
  // GQA: each group of heads_per_group query heads shares one KV head.
  const int64_t heads_per_group = Hq / Hkv;
  const int64_t hidden_size = Hq * D; // output row stride (ONNX convention)

  // Input strides depend on layout
  const int64_t ldq = is_transposed ? D : Hq * D;
  const int64_t ldk = is_transposed ? D : Hkv * D;
  const int64_t ldv = is_transposed ? D : Hkv * D;
  // Output is always [B, S, Hq, D] so ldo = Hq * D = hidden_size
  const int64_t ldo = hidden_size;

  // Parallelize over (batch, query-head) pairs; each unit is independent.
  torch::executor::parallel_for(
      0, batch * Hq, 1, [&](int64_t begin, int64_t end) {
        for (int64_t idx = begin; idx < end; ++idx) {
          const int64_t b = idx / Hq;
          const int64_t h = idx % Hq;
          const int64_t kv_h = h / heads_per_group;

          const float* q_ptr;
          const float* k_ptr;
          const float* v_ptr;
          if (is_transposed) {
            // [B,H,S,D]: each head's data is one contiguous S*D slab.
            q_ptr = q_data + (b * Hq + h) * q_seq_len * D;
            k_ptr = k_data + (b * Hkv + kv_h) * max_seq_len * D;
            v_ptr = v_data + (b * Hkv + kv_h) * max_seq_len * D;
          } else {
            // [B,S,H,D]: head slices are strided by H*D (= ldq/ldk/ldv).
            q_ptr = q_data + b * q_seq_len * Hq * D + h * D;
            k_ptr = k_data + b * max_seq_len * Hkv * D + kv_h * D;
            v_ptr = v_data + b * max_seq_len * Hkv * D + kv_h * D;
          }
          // Output always [B, S, Hq, D]: head h writes at stride hidden_size
          float* out_ptr =
              out_data + b * q_seq_len * hidden_size + h * D;

          // Scores padded to max_seq_len columns (ONNX convention)
          float* scores = scores_buf + idx * q_seq_len * max_seq_len;

          // GEMM 1: Q @ K^T with scale in alpha
          // NOTE(review): assumes cpublas::gemm follows the BLAS
          // column-major convention, so viewed row-major this yields
          // q_seq_len rows of total_seqlen scores with row stride
          // max_seq_len — confirm against executorch::cpublas docs.
          // Columns [total_seqlen, max_seq_len) stay uninitialized; they
          // are never read below.
          executorch::cpublas::gemm(
              TransposeType::Transpose, TransposeType::NoTranspose,
              total_seqlen, q_seq_len, D,
              alpha, k_ptr, ldk, q_ptr, ldq,
              0.0f, scores, max_seq_len);

          // Causal mask + narrow softmax (ONNX style):
          // Zero future positions, softmax only on valid [0, causal_len).
          for (int64_t qi = 0; qi < q_seq_len; ++qi) {
            float* row = scores + qi * max_seq_len;
            // Query qi may attend to the start_pos cached positions plus
            // itself and earlier in-chunk positions.
            const int64_t causal_len =
                std::min(start_pos + qi + 1, total_seqlen);

            // Zero masked-out (future) scores so GEMM 2 can multiply the
            // full total_seqlen window without a masked variant.
            for (int64_t j = causal_len; j < total_seqlen; ++j) {
              row[j] = 0.0f;
            }

            // Numerically stable softmax over the valid window only
            // (subtract the row max before exponentiating).
            float max_val = row[0];
            for (int64_t j = 1; j < causal_len; ++j) {
              max_val = std::max(max_val, row[j]);
            }
            float sum = 0.0f;
            for (int64_t j = 0; j < causal_len; ++j) {
              row[j] = std::exp(row[j] - max_val);
              sum += row[j];
            }
            // causal_len >= 1 and exp(..) > 0, so sum > 0: no div-by-zero.
            const float inv_sum = 1.0f / sum;
            for (int64_t j = 0; j < causal_len; ++j) {
              row[j] *= inv_sum;
            }
          }

          // GEMM 2: scores @ V -> output
          executorch::cpublas::gemm(
              TransposeType::NoTranspose, TransposeType::NoTranspose,
              D, q_seq_len, total_seqlen,
              1.0f, v_ptr, ldv, scores, max_seq_len,
              0.0f, out_ptr, ldo);
        }
      });
}
257+
150258
// Largest element-wise absolute difference between two float buffers of
// length n. Returns 0.0f when n == 0.
float max_abs_diff(const float* a, const float* b, int64_t n) {
  float worst = 0.0f;
  for (int64_t i = 0; i < n; ++i) {
    const float delta = std::abs(a[i] - b[i]);
    if (delta > worst) {
      worst = delta;
    }
  }
  return worst;
}
160266

267+
float max_abs_diff(const Tensor& a, const Tensor& b) {
268+
return max_abs_diff(
269+
a.const_data_ptr<float>(), b.const_data_ptr<float>(), a.numel());
270+
}
271+
161272
// Validate a single config: run StandardSDPA and custom_sdpa_out on the same
162273
// inputs, check outputs match within tolerance. Returns false on mismatch.
163274
// Only tests standard [B,S,H,D] layout (is_transposed=false).
@@ -226,6 +337,45 @@ bool validate_config(
226337
mode, (long)batch, (long)Hq, (long)Hkv, (long)D,
227338
(long)start_pos, (long)q_seq_len, diff);
228339

340+
// Also validate ONNX GQA variant. Output is always [B, S, Hq, D].
341+
// Since we only test standard [B,S,H,D] layout, out_ref is already
342+
// [B,S,Hq,D] — just copy directly to ref_bshd (no transpose needed).
343+
Tensor out_onnx =
344+
tf.zeros({(int32_t)batch, (int32_t)q_seq_len, (int32_t)Hq, (int32_t)D});
345+
std::vector<float> onnx_scores_buf(batch * Hq * q_seq_len * max_seq_len);
346+
run_onnx_gqa_sdpa(
347+
q.const_data_ptr<float>(),
348+
k.const_data_ptr<float>(),
349+
v.const_data_ptr<float>(),
350+
out_onnx.mutable_data_ptr<float>(),
351+
onnx_scores_buf.data(),
352+
batch, Hq, Hkv, D, max_seq_len, start_pos, q_seq_len,
353+
false /* is_transposed */);
354+
355+
// out_ref is already [B, S, Hq, D] (standard layout), compare directly
356+
std::vector<float> ref_bshd(batch * q_seq_len * Hq * D);
357+
const float* ref_ptr = out_ref.const_data_ptr<float>();
358+
std::copy(ref_ptr, ref_ptr + batch * q_seq_len * Hq * D, ref_bshd.data());
359+
360+
float onnx_diff = max_abs_diff(
361+
out_onnx.const_data_ptr<float>(), ref_bshd.data(),
362+
batch * q_seq_len * Hq * D);
363+
if (onnx_diff > atol) {
364+
fprintf(
365+
stderr,
366+
"FAIL: OnnxGQA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
367+
"max_abs_diff=%.6e > atol=%.6e\n",
368+
mode, (long)batch, (long)Hq, (long)Hkv, (long)D,
369+
(long)start_pos, (long)q_seq_len, onnx_diff, atol);
370+
return false;
371+
}
372+
fprintf(
373+
stderr,
374+
"PASS: OnnxGQA standard %s (B=%ld Hq=%ld Hkv=%ld D=%ld sp=%ld sl=%ld) "
375+
"max_abs_diff=%.6e\n",
376+
mode, (long)batch, (long)Hq, (long)Hkv, (long)D,
377+
(long)start_pos, (long)q_seq_len, onnx_diff);
378+
229379
return true;
230380
}
231381

@@ -429,6 +579,101 @@ BENCHMARK_DEFINE_F(StandardSDPABenchFixture, StandardSDPA)
429579
}
430580
}
431581

582+
// ONNX Runtime GQA-style benchmark. Faithfully matches the algorithm from
583+
// gqa_attention_base.h: scale-in-alpha, padded scores buffer, narrow softmax,
584+
// and output in [B, S, Hq, D] with stride Hq*D.
585+
class OnnxGQABenchFixture : public benchmark::Fixture {
586+
public:
587+
// Args: {batch, num_heads_q, num_heads_kv, head_dim, max_seq_len, start_pos,
588+
// query_seq_len, is_transposed}
589+
void SetUp(benchmark::State& state) override {
590+
int64_t batch = state.range(0);
591+
int64_t num_heads_q = state.range(1);
592+
int64_t num_heads_kv = state.range(2);
593+
int64_t head_dim = state.range(3);
594+
int64_t max_seq_len = state.range(4);
595+
int64_t start_pos = state.range(5);
596+
int64_t q_seq_len = state.range(6);
597+
bool is_transposed = state.range(7) != 0;
598+
599+
std::mt19937 gen(42);
600+
601+
if (is_transposed) {
602+
q_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_q,
603+
(int32_t)q_seq_len, (int32_t)head_dim}));
604+
k_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_kv,
605+
(int32_t)max_seq_len, (int32_t)head_dim}));
606+
v_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)num_heads_kv,
607+
(int32_t)max_seq_len, (int32_t)head_dim}));
608+
} else {
609+
q_.emplace(tf_.zeros({(int32_t)batch, (int32_t)q_seq_len,
610+
(int32_t)num_heads_q, (int32_t)head_dim}));
611+
k_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
612+
(int32_t)num_heads_kv, (int32_t)head_dim}));
613+
v_cache_.emplace(tf_.zeros({(int32_t)batch, (int32_t)max_seq_len,
614+
(int32_t)num_heads_kv, (int32_t)head_dim}));
615+
}
616+
// Output always [B, S, Hq, D] (ONNX convention)
617+
output_.emplace(tf_.zeros({(int32_t)batch, (int32_t)q_seq_len,
618+
(int32_t)num_heads_q, (int32_t)head_dim}));
619+
620+
fill_random(*q_, gen);
621+
fill_random(*k_cache_, gen);
622+
fill_random(*v_cache_, gen);
623+
624+
batch_ = batch;
625+
num_heads_q_ = num_heads_q;
626+
num_heads_kv_ = num_heads_kv;
627+
head_dim_ = head_dim;
628+
max_seq_len_ = max_seq_len;
629+
start_pos_ = start_pos;
630+
q_seq_len_ = q_seq_len;
631+
is_transposed_ = is_transposed;
632+
633+
// Scores buffer padded to max_seq_len columns (ONNX convention)
634+
int64_t total_units = batch * num_heads_q;
635+
scores_buf_.resize(total_units * q_seq_len * max_seq_len);
636+
}
637+
638+
void TearDown(benchmark::State&) override {
639+
q_.reset();
640+
k_cache_.reset();
641+
v_cache_.reset();
642+
output_.reset();
643+
scores_buf_.clear();
644+
}
645+
646+
TensorFactory<ScalarType::Float> tf_;
647+
std::optional<Tensor> q_;
648+
std::optional<Tensor> k_cache_;
649+
std::optional<Tensor> v_cache_;
650+
std::optional<Tensor> output_;
651+
std::vector<float> scores_buf_;
652+
int64_t batch_ = 0;
653+
int64_t num_heads_q_ = 0;
654+
int64_t num_heads_kv_ = 0;
655+
int64_t head_dim_ = 0;
656+
int64_t max_seq_len_ = 0;
657+
int64_t start_pos_ = 0;
658+
int64_t q_seq_len_ = 0;
659+
bool is_transposed_ = false;
660+
};
661+
662+
// Benchmark body: hoist raw data pointers out of the timed loop, then run
// the ONNX GQA SDPA kernel once per benchmark iteration.
BENCHMARK_DEFINE_F(OnnxGQABenchFixture, OnnxGQA)
(benchmark::State& state) {
  const float* q = q_->const_data_ptr<float>();
  const float* k = k_cache_->const_data_ptr<float>();
  const float* v = v_cache_->const_data_ptr<float>();
  float* out = output_->mutable_data_ptr<float>();

  for (auto _ : state) {
    run_onnx_gqa_sdpa(
        q,
        k,
        v,
        out,
        scores_buf_.data(),
        batch_,
        num_heads_q_,
        num_heads_kv_,
        head_dim_,
        max_seq_len_,
        start_pos_,
        q_seq_len_,
        is_transposed_);
  }
}
676+
432677
/*
433678
* Benchmark configurations modeled after Llama 3 8B (GQA: 32 q heads, 8 kv
434679
* heads, head_dim=128). We test decode (seq_len=1) and prefill scenarios at
@@ -478,6 +723,34 @@ BENCHMARK_REGISTER_F(StandardSDPABenchFixture, StandardSDPA)
478723
->ArgNames(
479724
{"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"});
480725

726+
// --- ONNX Runtime GQA-style SDPA ---
// Same configs as StandardSDPA. Differences: scale-in-alpha, padded scores
// buffer (ld=MaxS), narrow softmax, output in [B,S,Hq,D] with stride Hq*D.
// Arg order per row: {B, Hq, Hkv, D, MaxS, StartPos, SeqLen, Trans}.
BENCHMARK_REGISTER_F(OnnxGQABenchFixture, OnnxGQA)
    // Standard layout decode at various cache positions
    ->Args({1, 32, 8, 128, 2048, 0, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 64, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 256, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 512, 1, 0})
    ->Args({1, 32, 8, 128, 2048, 1024, 1, 0})
    // Transposed layout decode at same positions
    ->Args({1, 32, 8, 128, 2048, 0, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 64, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 256, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 512, 1, 1})
    ->Args({1, 32, 8, 128, 2048, 1024, 1, 1})
    // Standard layout prefill
    ->Args({1, 32, 8, 128, 2048, 0, 128, 0})
    ->Args({1, 32, 8, 128, 2048, 0, 512, 0})
    // Transposed layout prefill
    ->Args({1, 32, 8, 128, 2048, 0, 128, 1})
    ->Args({1, 32, 8, 128, 2048, 0, 512, 1})
    // Llama 2 style (32 heads, no GQA)
    ->Args({1, 32, 32, 128, 2048, 256, 1, 0})
    ->Args({1, 32, 32, 128, 2048, 256, 1, 1})
    ->ArgNames(
        {"B", "Hq", "Hkv", "D", "MaxS", "StartPos", "SeqLen", "Trans"});
753+
481754
} // namespace
482755

483756
int main(int argc, char** argv) {

0 commit comments

Comments
 (0)