Skip to content

Commit 66b5c73

Browse files
committed
[Executorch] Add non-flash SDPA for decode
Add cpu_sdpa template function in op_sdpa_impl.h that provides a simpler SDPA implementation using standard GEMM (no tiling). This is useful as a baseline and for cases where flash attention is not optimal. The implementation uses a single SeqDim parameter for all tensors and supports causal masking, attention masks, GQA, and multi-threading. During decode (seq_len == 1), the tiled flash attention implementation has unnecessary overhead from its blocking/tiling logic. The simpler unfused SDPA path using direct GEMM is more efficient for single-query attention, yielding ~25-30% decode throughput improvement on S25 (41 -> 53 tok/s for 1.4B parameter model). This makes cpu_sdpa always available (previously gated behind ET_USE_UNFUSED_SDPA) and dispatches to it when seq_len == 1 and inputs are not quantized. Prefill continues to use flash attention. Differential Revision: [D96044318](https://our.internmc.facebook.com/intern/diff/D96044318/) ghstack-source-id: 361224785 Pull Request resolved: #18648
1 parent d142f79 commit 66b5c73

File tree

4 files changed

+509
-26
lines changed

4 files changed

+509
-26
lines changed
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Tests for the unfused SDPA code path (cpu_sdpa) dispatched when
10+
// seq_len == 1 and inputs are non-quantized (the decode fast-path).
11+
// These call custom_sdpa_out directly, not through sdpa_with_kv_cache.
12+
13+
#include <algorithm>
14+
#include <cmath>
15+
#include <limits>
16+
#include <vector>
17+
18+
#include <executorch/extension/llm/custom_ops/op_sdpa.h>
19+
#include <executorch/kernels/test/TestUtil.h>
20+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
21+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
22+
23+
#include <gtest/gtest.h>
24+
25+
using namespace ::testing;
26+
using executorch::runtime::testing::TensorFactory;
27+
28+
namespace {
29+
30+
// Helper: call custom_sdpa_out. Inputs use [B, S, H, D] layout.
31+
executorch::aten::Tensor call_custom_sdpa(
32+
const executorch::aten::Tensor& q,
33+
const executorch::aten::Tensor& k,
34+
const executorch::aten::Tensor& v,
35+
int64_t start_pos,
36+
const std::optional<executorch::aten::Tensor>& attn_mask,
37+
double dropout_p,
38+
bool is_causal,
39+
std::optional<double> scale,
40+
executorch::aten::Tensor& out) {
41+
executorch::runtime::KernelRuntimeContext ctx{};
42+
return torch::executor::native::custom_sdpa_out(
43+
ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, out);
44+
}
45+
46+
/**
 * Naive reference SDPA for tensors in [B, S, H, D] layout.
 *
 * Element [b, s, h, d] lives at flat index b*S*H*D + s*H*D + h*D + d.
 * Only the first num_valid_keys KV entries participate; trailing cache
 * slots are ignored. Supports GQA: qH must be a multiple of kvH, and
 * query head h attends to KV head h / (qH / kvH).
 *
 * @param q_data         Query data, [B, qS, qH, D].
 * @param B,qS,qH,D      Query dimensions.
 * @param k_data         Key data, [B, kvS, kvH, D].
 * @param kvS,kvH        KV cache sequence length and head count.
 * @param v_data         Value data, [B, kvS, kvH, D].
 * @param out_data       Output buffer, [B, qS, qH, D]; fully overwritten.
 * @param is_causal      Apply causal masking relative to start_pos.
 * @param start_pos      Absolute position of query index 0 in the sequence.
 * @param num_valid_keys Number of leading KV entries to attend over.
 */
void compute_reference_sdpa(
    const float* q_data,
    int B, int qS, int qH, int D,
    const float* k_data,
    int kvS, int kvH,
    const float* v_data,
    float* out_data,
    bool is_causal,
    int64_t start_pos,
    int num_valid_keys) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(D));
  const int num_reps = qH / kvH; // query heads per KV head (GQA)

  for (int b = 0; b < B; b++) {
    for (int h = 0; h < qH; h++) {
      const int kv_h = h / num_reps;
      for (int qs = 0; qs < qS; qs++) {
        // Degenerate case: nothing to attend over. Define the output as
        // zero rather than dereferencing max_element on an empty range
        // (which is undefined behavior).
        if (num_valid_keys <= 0) {
          for (int d = 0; d < D; d++) {
            out_data[b*qS*qH*D + qs*qH*D + h*D + d] = 0.0f;
          }
          continue;
        }

        // scores = Q @ K^T * scale
        std::vector<float> scores(num_valid_keys);
        for (int kvs = 0; kvs < num_valid_keys; kvs++) {
          float dot = 0;
          for (int d = 0; d < D; d++) {
            float qv = q_data[b*qS*qH*D + qs*qH*D + h*D + d];
            float kv = k_data[b*kvS*kvH*D + kvs*kvH*D + kv_h*D + d];
            dot += qv * kv;
          }
          scores[kvs] = dot * scale;
        }

        // Causal mask: the query at absolute position start_pos + qs
        // may only see keys at positions <= its own.
        if (is_causal) {
          int64_t valid = std::min(
              start_pos + qs + 1,
              static_cast<int64_t>(num_valid_keys));
          for (int64_t j = valid; j < num_valid_keys; j++) {
            scores[j] = -std::numeric_limits<float>::infinity();
          }
        }

        // Numerically stable softmax (subtract the row maximum).
        float max_val = *std::max_element(scores.begin(), scores.end());
        if (std::isinf(max_val) && max_val < 0) {
          // Every position is masked out: define the attention weights
          // as zero instead of propagating NaN from exp(-inf - -inf).
          std::fill(scores.begin(), scores.end(), 0.0f);
        } else {
          float sum = 0;
          for (auto& s : scores) {
            s = std::exp(s - max_val);
            sum += s;
          }
          if (sum > 0) {
            for (auto& s : scores) {
              s /= sum;
            }
          }
        }

        // output = scores @ V
        for (int d = 0; d < D; d++) {
          float val = 0;
          for (int kvs = 0; kvs < num_valid_keys; kvs++) {
            float vv = v_data[b*kvS*kvH*D + kvs*kvH*D + kv_h*D + d];
            val += scores[kvs] * vv;
          }
          out_data[b*qS*qH*D + qs*qH*D + h*D + d] = val;
        }
      }
    }
  }
}
116+
117+
} // namespace
118+
119+
// With a single KV entry (start_pos=0), output must equal V[0].
120+
TEST(OpCustomSdpaTest, DecodeSingleKV) {
121+
TensorFactory<executorch::aten::ScalarType::Float> tf;
122+
123+
executorch::aten::Tensor q = tf.make(
124+
{1, 1, 2, 4},
125+
{0.8823, 0.9150, 0.3829, 0.9593,
126+
0.3904, 0.6009, 0.2566, 0.7936});
127+
128+
executorch::aten::Tensor k = tf.make(
129+
{1, 1, 2, 4},
130+
{0.8854, 0.5739, 0.2666, 0.6274,
131+
0.2696, 0.4414, 0.2969, 0.8317});
132+
133+
executorch::aten::Tensor v = tf.make(
134+
{1, 1, 2, 4},
135+
{0.6343, 0.3644, 0.7104, 0.9464,
136+
0.7890, 0.2814, 0.7886, 0.5895});
137+
138+
// softmax of a single score is always 1.0, so output == V
139+
executorch::aten::Tensor expected = tf.make(
140+
{1, 1, 2, 4},
141+
{0.6343, 0.3644, 0.7104, 0.9464,
142+
0.7890, 0.2814, 0.7886, 0.5895});
143+
144+
executorch::aten::Tensor out = tf.zeros({1, 1, 2, 4});
145+
call_custom_sdpa(q, k, v, /*start_pos=*/0, {}, 0.0, false, {}, out);
146+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-6, 1e-6);
147+
}
148+
149+
// Decode with 3 valid KV entries, verified against reference computation.
150+
TEST(OpCustomSdpaTest, DecodeNonCausal) {
151+
TensorFactory<executorch::aten::ScalarType::Float> tf;
152+
153+
// Q: [B=1, S=1, H=2, D=4]
154+
executorch::aten::Tensor q = tf.make(
155+
{1, 1, 2, 4},
156+
{0.8823, 0.9150, 0.3829, 0.9593,
157+
0.3904, 0.6009, 0.2566, 0.7936});
158+
159+
// K, V: [B=1, kv_len=4, H=2, D=4], first 3 entries valid
160+
executorch::aten::Tensor k = tf.make(
161+
{1, 4, 2, 4},
162+
{0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
163+
0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753,
164+
0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
165+
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
166+
167+
executorch::aten::Tensor v = tf.make(
168+
{1, 4, 2, 4},
169+
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895,
170+
0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071,
171+
0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278,
172+
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
173+
174+
int64_t start_pos = 2;
175+
int num_valid = 3;
176+
177+
std::vector<float> ref(8, 0.0f);
178+
compute_reference_sdpa(
179+
q.const_data_ptr<float>(), 1, 1, 2, 4,
180+
k.const_data_ptr<float>(), 4, 2,
181+
v.const_data_ptr<float>(),
182+
ref.data(), false, start_pos, num_valid);
183+
184+
executorch::aten::Tensor expected = tf.make({1, 1, 2, 4}, ref);
185+
executorch::aten::Tensor out = tf.zeros({1, 1, 2, 4});
186+
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out);
187+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-4, 1e-4);
188+
}
189+
190+
// GQA: 4 query heads sharing 2 KV heads.
191+
TEST(OpCustomSdpaTest, DecodeGQA) {
192+
TensorFactory<executorch::aten::ScalarType::Float> tf;
193+
194+
// Q: [B=1, S=1, H_q=4, D=4]
195+
executorch::aten::Tensor q = tf.make(
196+
{1, 1, 4, 4},
197+
{0.8823, 0.9150, 0.3829, 0.9593,
198+
0.3904, 0.6009, 0.2566, 0.7936,
199+
0.9408, 0.1332, 0.9346, 0.5936,
200+
0.8694, 0.5677, 0.7411, 0.4294});
201+
202+
// K: [B=1, kv_len=3, H_kv=2, D=4]
203+
executorch::aten::Tensor k = tf.make(
204+
{1, 3, 2, 4},
205+
{0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
206+
0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753,
207+
0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423});
208+
209+
// V: [B=1, kv_len=3, H_kv=2, D=4]
210+
executorch::aten::Tensor v = tf.make(
211+
{1, 3, 2, 4},
212+
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895,
213+
0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071,
214+
0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278});
215+
216+
int64_t start_pos = 2;
217+
int num_valid = 3;
218+
219+
std::vector<float> ref(16, 0.0f);
220+
compute_reference_sdpa(
221+
q.const_data_ptr<float>(), 1, 1, 4, 4,
222+
k.const_data_ptr<float>(), 3, 2,
223+
v.const_data_ptr<float>(),
224+
ref.data(), false, start_pos, num_valid);
225+
226+
executorch::aten::Tensor expected = tf.make({1, 1, 4, 4}, ref);
227+
executorch::aten::Tensor out = tf.zeros({1, 1, 4, 4});
228+
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out);
229+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-4, 1e-4);
230+
}
231+
232+
// For seq_len=1, causal mask doesn't restrict any positions
233+
// (all start_pos+1 entries are visible), so result must match non-causal.
234+
TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) {
235+
TensorFactory<executorch::aten::ScalarType::Float> tf;
236+
237+
executorch::aten::Tensor q = tf.make(
238+
{1, 1, 2, 4},
239+
{0.8823, 0.9150, 0.3829, 0.9593,
240+
0.3904, 0.6009, 0.2566, 0.7936});
241+
242+
executorch::aten::Tensor k = tf.make(
243+
{1, 4, 2, 4},
244+
{0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
245+
0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753,
246+
0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
247+
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
248+
249+
executorch::aten::Tensor v = tf.make(
250+
{1, 4, 2, 4},
251+
{0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895,
252+
0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071,
253+
0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278,
254+
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
255+
256+
int64_t start_pos = 2;
257+
258+
executorch::aten::Tensor out_nc = tf.zeros({1, 1, 2, 4});
259+
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out_nc);
260+
261+
executorch::aten::Tensor out_c = tf.zeros({1, 1, 2, 4});
262+
call_custom_sdpa(q, k, v, start_pos, {}, 0.0, true, {}, out_c);
263+
264+
EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6);
265+
}

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -412,13 +412,19 @@ Tensor& custom_sdpa_out_impl(
412412
InvalidArgument,
413413
output);
414414

415-
// TODO(task): replace the template param selection logic
416-
// with whatever apprpriately makes more sense for
415+
bool use_unfused_sdpa = q.scalar_type() != ScalarType::Char &&
416+
seq_len == 1;
417+
if (use_unfused_sdpa) {
418+
ET_SWITCH_FLOAT_TYPES(
419+
output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
420+
sdpa::impl::cpu_sdpa<CTYPE>(
421+
ctx, output, q, k, v, is_causal, attn_mask, scale,
422+
seq_dim,
423+
start_pos, num_keys_for_causal_attention);
424+
});
425+
} else {
417426
ET_SWITCH_FLOAT_TYPES(
418427
output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
419-
// TODO we need to re-evaluate this for ARM CPUs
420-
// And there can be many so instead of templatizing
421-
// we might consider another appraoch
422428
if (seq_len >= 768) {
423429
sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
424430
ctx,
@@ -430,13 +436,13 @@ Tensor& custom_sdpa_out_impl(
430436
is_causal,
431437
attn_mask,
432438
scale,
433-
q_zero_points, // q_zero_points
434-
q_scales, // q_scales
435-
k_zero_points, // k_zero_points
436-
k_scales, // k_scales
437-
v_zero_points, // v_zero_points
438-
v_scales, // v_scales
439-
seq_dim, /* seq_dim */
439+
q_zero_points,
440+
q_scales,
441+
k_zero_points,
442+
k_scales,
443+
v_zero_points,
444+
v_scales,
445+
seq_dim,
440446
start_pos,
441447
num_keys_for_causal_attention);
442448
} else if (seq_len >= 192) {
@@ -450,13 +456,13 @@ Tensor& custom_sdpa_out_impl(
450456
is_causal,
451457
attn_mask,
452458
scale,
453-
q_zero_points, // q_zero_points
454-
q_scales, // q_scales
455-
k_zero_points, // k_zero_points
456-
k_scales, // k_scales
457-
v_zero_points, // v_zero_points
458-
v_scales, // v_scales
459-
seq_dim, /* seq_dim */
459+
q_zero_points,
460+
q_scales,
461+
k_zero_points,
462+
k_scales,
463+
v_zero_points,
464+
v_scales,
465+
seq_dim,
460466
start_pos,
461467
num_keys_for_causal_attention);
462468
} else {
@@ -470,17 +476,18 @@ Tensor& custom_sdpa_out_impl(
470476
is_causal,
471477
attn_mask,
472478
scale,
473-
q_zero_points, // q_zero_points
474-
q_scales, // q_scales
475-
k_zero_points, // k_zero_points
476-
k_scales, // k_scales
477-
v_zero_points, // v_zero_points
478-
v_scales, // v_scales
479-
seq_dim, /* seq_dim */
479+
q_zero_points,
480+
q_scales,
481+
k_zero_points,
482+
k_scales,
483+
v_zero_points,
484+
v_scales,
485+
seq_dim,
480486
start_pos,
481487
num_keys_for_causal_attention);
482488
}
483489
});
490+
}
484491
return output;
485492
}
486493

0 commit comments

Comments
 (0)