Skip to content

Commit f3013bf

Browse files
committed
[Executorch] Add non-flash SDPA for decode
Pull Request resolved: #18648 Add cpu_sdpa template function in op_sdpa_impl.h that provides a simpler SDPA implementation using standard GEMM (no tiling). This is useful as a baseline and for cases where flash attention is not optimal. The implementation uses a single SeqDim parameter for all tensors and supports causal masking, attention masks, GQA, and multi-threading. During decode (seq_len == 1), the tiled flash attention implementation has unnecessary overhead from its blocking/tiling logic. The simpler unfused SDPA path using direct GEMM is more efficient for single-query attention, yielding ~25-30% decode throughput improvement on S25 (41 -> 53 tok/s for 1.4B parameter model). This makes cpu_sdpa always available (previously gated behind ET_USE_UNFUSED_SDPA) and dispatches to it when seq_len == 1 and inputs are not quantized. Prefill continues to use flash attention. ghstack-source-id: 374666319 @exported-using-ghexport Differential Revision: [D96044318](https://our.internmc.facebook.com/intern/diff/D96044318/)
1 parent 5cc3cde commit f3013bf

4 files changed

Lines changed: 593 additions & 69 deletions

File tree

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Tests for the unfused SDPA code path (cpu_sdpa) dispatched when
10+
// seq_len == 1 and inputs are non-quantized (the decode fast-path).
11+
// These call custom_sdpa_out directly, not through sdpa_with_kv_cache.
12+
13+
#include <algorithm>
14+
#include <cmath>
15+
#include <limits>
16+
#include <vector>
17+
18+
#include <executorch/extension/llm/custom_ops/op_sdpa.h>
19+
#include <executorch/kernels/test/TestUtil.h>
20+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
21+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
22+
23+
#include <gtest/gtest.h>
24+
25+
using namespace ::testing;
26+
using executorch::runtime::testing::TensorFactory;
27+
28+
namespace {
29+
30+
// Helper: call custom_sdpa_out. Inputs use [B, S, H, D] layout.
31+
executorch::aten::Tensor call_custom_sdpa(
32+
const executorch::aten::Tensor& q,
33+
const executorch::aten::Tensor& k,
34+
const executorch::aten::Tensor& v,
35+
int64_t start_pos,
36+
const std::optional<executorch::aten::Tensor>& attn_mask,
37+
double dropout_p,
38+
bool is_causal,
39+
std::optional<double> scale,
40+
executorch::aten::Tensor& out) {
41+
executorch::runtime::KernelRuntimeContext ctx{};
42+
return torch::executor::native::custom_sdpa_out(
43+
ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, out);
44+
}
45+
46+
/**
47+
* Naive reference SDPA for [B, S, H, D] layout.
48+
* Element [b,s,h,d] is at index b*S*H*D + s*H*D + h*D + d.
49+
* Only first num_valid_keys KV entries are used.
50+
*/
51+
void compute_reference_sdpa(
52+
const float* q_data,
53+
int B,
54+
int qS,
55+
int qH,
56+
int D,
57+
const float* k_data,
58+
int kvS,
59+
int kvH,
60+
const float* v_data,
61+
float* out_data,
62+
bool is_causal,
63+
int64_t start_pos,
64+
int num_valid_keys) {
65+
float scale = 1.0f / std::sqrt(static_cast<float>(D));
66+
int num_reps = qH / kvH;
67+
68+
for (int b = 0; b < B; b++) {
69+
for (int h = 0; h < qH; h++) {
70+
int kv_h = h / num_reps;
71+
for (int qs = 0; qs < qS; qs++) {
72+
// scores = Q @ K^T * scale
73+
std::vector<float> scores(num_valid_keys);
74+
for (int kvs = 0; kvs < num_valid_keys; kvs++) {
75+
float dot = 0;
76+
for (int d = 0; d < D; d++) {
77+
float qv = q_data[b * qS * qH * D + qs * qH * D + h * D + d];
78+
float kv = k_data[b * kvS * kvH * D + kvs * kvH * D + kv_h * D + d];
79+
dot += qv * kv;
80+
}
81+
scores[kvs] = dot * scale;
82+
}
83+
84+
// Causal mask
85+
if (is_causal) {
86+
int64_t valid = std::min(
87+
start_pos + qs + 1, static_cast<int64_t>(num_valid_keys));
88+
for (int64_t j = valid; j < num_valid_keys; j++) {
89+
scores[j] = -std::numeric_limits<float>::infinity();
90+
}
91+
}
92+
93+
// Softmax
94+
float max_val = *std::max_element(scores.begin(), scores.end());
95+
float sum = 0;
96+
for (auto& s : scores) {
97+
s = std::exp(s - max_val);
98+
sum += s;
99+
}
100+
if (sum > 0) {
101+
for (auto& s : scores) {
102+
s /= sum;
103+
}
104+
}
105+
106+
// output = scores @ V
107+
for (int d = 0; d < D; d++) {
108+
float val = 0;
109+
for (int kvs = 0; kvs < num_valid_keys; kvs++) {
110+
float vv = v_data[b * kvS * kvH * D + kvs * kvH * D + kv_h * D + d];
111+
val += scores[kvs] * vv;
112+
}
113+
out_data[b * qS * qH * D + qs * qH * D + h * D + d] = val;
114+
}
115+
}
116+
}
117+
}
118+
}
119+
120+
} // namespace
121+
122+
// With a single KV entry (start_pos=0), output must equal V[0].
123+
TEST(OpCustomSdpaTest, DecodeSingleKV) {
124+
TensorFactory<executorch::aten::ScalarType::Float> tf;
125+
126+
executorch::aten::Tensor q = tf.make(
127+
{1, 1, 2, 4},
128+
{0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936});
129+
130+
executorch::aten::Tensor k = tf.make(
131+
{1, 1, 2, 4},
132+
{0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317});
133+
134+
executorch::aten::Tensor v = tf.make(
135+
{1, 1, 2, 4},
136+
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});
137+
138+
// softmax of a single score is always 1.0, so output == V
139+
executorch::aten::Tensor expected = tf.make(
140+
{1, 1, 2, 4},
141+
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});
142+
143+
executorch::aten::Tensor out = tf.zeros({1, 1, 2, 4});
144+
call_custom_sdpa(q, k, v, /*start_pos=*/0, {}, 0.0, false, {}, out);
145+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-6, 1e-6);
146+
}
147+
148+
// Decode with 3 valid KV entries, verified against reference computation.
149+
TEST(OpCustomSdpaTest, DecodeNonCausal) {
150+
TensorFactory<executorch::aten::ScalarType::Float> tf;
151+
152+
// Q: [B=1, S=1, H=2, D=4]
153+
executorch::aten::Tensor q = tf.make(
154+
{1, 1, 2, 4},
155+
{0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936});
156+
157+
// K, V: [B=1, kv_len=4, H=2, D=4], first 3 entries valid
158+
executorch::aten::Tensor k = tf.make(
159+
{1, 4, 2, 4},
160+
{0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
161+
0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753,
162+
0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
163+
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
164+
165+
executorch::aten::Tensor v = tf.make(
166+
{1, 4, 2, 4},
167+
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895,
168+
0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071,
169+
0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278,
170+
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
171+
172+
int64_t start_pos = 2;
173+
int num_valid = 3;
174+
175+
std::vector<float> ref(8, 0.0f);
176+
compute_reference_sdpa(
177+
q.const_data_ptr<float>(),
178+
1,
179+
1,
180+
2,
181+
4,
182+
k.const_data_ptr<float>(),
183+
4,
184+
2,
185+
v.const_data_ptr<float>(),
186+
ref.data(),
187+
false,
188+
start_pos,
189+
num_valid);
190+
191+
executorch::aten::Tensor expected = tf.make({1, 1, 2, 4}, ref);
192+
executorch::aten::Tensor out = tf.zeros({1, 1, 2, 4});
193+
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out);
194+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-4, 1e-4);
195+
}
196+
197+
// GQA: 4 query heads sharing 2 KV heads.
198+
TEST(OpCustomSdpaTest, DecodeGQA) {
199+
TensorFactory<executorch::aten::ScalarType::Float> tf;
200+
201+
// Q: [B=1, S=1, H_q=4, D=4]
202+
executorch::aten::Tensor q = tf.make(
203+
{1, 1, 4, 4},
204+
{0.8823,
205+
0.9150,
206+
0.3829,
207+
0.9593,
208+
0.3904,
209+
0.6009,
210+
0.2566,
211+
0.7936,
212+
0.9408,
213+
0.1332,
214+
0.9346,
215+
0.5936,
216+
0.8694,
217+
0.5677,
218+
0.7411,
219+
0.4294});
220+
221+
// K: [B=1, kv_len=3, H_kv=2, D=4]
222+
executorch::aten::Tensor k =
223+
tf.make({1, 3, 2, 4}, {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414,
224+
0.2969, 0.8317, 0.1053, 0.2695, 0.3588, 0.1994,
225+
0.5472, 0.0062, 0.9516, 0.0753, 0.8860, 0.5832,
226+
0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423});
227+
228+
// V: [B=1, kv_len=3, H_kv=2, D=4]
229+
executorch::aten::Tensor v =
230+
tf.make({1, 3, 2, 4}, {0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814,
231+
0.7886, 0.5895, 0.7539, 0.1952, 0.0050, 0.3068,
232+
0.1165, 0.9103, 0.6440, 0.7071, 0.6581, 0.4913,
233+
0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278});
234+
235+
int64_t start_pos = 2;
236+
int num_valid = 3;
237+
238+
std::vector<float> ref(16, 0.0f);
239+
compute_reference_sdpa(
240+
q.const_data_ptr<float>(),
241+
1,
242+
1,
243+
4,
244+
4,
245+
k.const_data_ptr<float>(),
246+
3,
247+
2,
248+
v.const_data_ptr<float>(),
249+
ref.data(),
250+
false,
251+
start_pos,
252+
num_valid);
253+
254+
executorch::aten::Tensor expected = tf.make({1, 1, 4, 4}, ref);
255+
executorch::aten::Tensor out = tf.zeros({1, 1, 4, 4});
256+
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out);
257+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-4, 1e-4);
258+
}
259+
260+
// For seq_len=1, causal mask doesn't restrict any positions
261+
// (all start_pos+1 entries are visible), so result must match non-causal.
262+
TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) {
263+
TensorFactory<executorch::aten::ScalarType::Float> tf;
264+
265+
executorch::aten::Tensor q = tf.make(
266+
{1, 1, 2, 4},
267+
{0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936});
268+
269+
executorch::aten::Tensor k = tf.make(
270+
{1, 4, 2, 4},
271+
{0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
272+
0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753,
273+
0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
274+
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
275+
276+
executorch::aten::Tensor v = tf.make(
277+
{1, 4, 2, 4},
278+
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895,
279+
0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071,
280+
0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278,
281+
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
282+
283+
int64_t start_pos = 2;
284+
285+
executorch::aten::Tensor out_nc = tf.zeros({1, 1, 2, 4});
286+
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out_nc);
287+
288+
executorch::aten::Tensor out_c = tf.zeros({1, 1, 2, 4});
289+
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, true, {}, out_c);
290+
291+
EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6);
292+
}

0 commit comments

Comments
 (0)