|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +// Tests for the unfused SDPA code path (cpu_sdpa) dispatched when |
| 10 | +// seq_len == 1 and inputs are non-quantized (the decode fast-path). |
| 11 | +// These call custom_sdpa_out directly, not through sdpa_with_kv_cache. |
| 12 | + |
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>
| 17 | + |
| 18 | +#include <executorch/extension/llm/custom_ops/op_sdpa.h> |
| 19 | +#include <executorch/kernels/test/TestUtil.h> |
| 20 | +#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h> |
| 21 | +#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h> |
| 22 | + |
| 23 | +#include <gtest/gtest.h> |
| 24 | + |
| 25 | +using namespace ::testing; |
| 26 | +using executorch::runtime::testing::TensorFactory; |
| 27 | + |
| 28 | +namespace { |
| 29 | + |
| 30 | +// Helper: call custom_sdpa_out. Inputs use [B, S, H, D] layout. |
| 31 | +executorch::aten::Tensor call_custom_sdpa( |
| 32 | + const executorch::aten::Tensor& q, |
| 33 | + const executorch::aten::Tensor& k, |
| 34 | + const executorch::aten::Tensor& v, |
| 35 | + int64_t start_pos, |
| 36 | + const std::optional<executorch::aten::Tensor>& attn_mask, |
| 37 | + double dropout_p, |
| 38 | + bool is_causal, |
| 39 | + std::optional<double> scale, |
| 40 | + executorch::aten::Tensor& out) { |
| 41 | + executorch::runtime::KernelRuntimeContext ctx{}; |
| 42 | + return torch::executor::native::custom_sdpa_out( |
| 43 | + ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, out); |
| 44 | +} |
| 45 | + |
| 46 | +/** |
| 47 | + * Naive reference SDPA for [B, S, H, D] layout. |
| 48 | + * Element [b,s,h,d] is at index b*S*H*D + s*H*D + h*D + d. |
| 49 | + * Only first num_valid_keys KV entries are used. |
| 50 | + */ |
| 51 | +void compute_reference_sdpa( |
| 52 | + const float* q_data, |
| 53 | + int B, int qS, int qH, int D, |
| 54 | + const float* k_data, |
| 55 | + int kvS, int kvH, |
| 56 | + const float* v_data, |
| 57 | + float* out_data, |
| 58 | + bool is_causal, |
| 59 | + int64_t start_pos, |
| 60 | + int num_valid_keys) { |
| 61 | + float scale = 1.0f / std::sqrt(static_cast<float>(D)); |
| 62 | + int num_reps = qH / kvH; |
| 63 | + |
| 64 | + for (int b = 0; b < B; b++) { |
| 65 | + for (int h = 0; h < qH; h++) { |
| 66 | + int kv_h = h / num_reps; |
| 67 | + for (int qs = 0; qs < qS; qs++) { |
| 68 | + // scores = Q @ K^T * scale |
| 69 | + std::vector<float> scores(num_valid_keys); |
| 70 | + for (int kvs = 0; kvs < num_valid_keys; kvs++) { |
| 71 | + float dot = 0; |
| 72 | + for (int d = 0; d < D; d++) { |
| 73 | + float qv = q_data[b*qS*qH*D + qs*qH*D + h*D + d]; |
| 74 | + float kv = k_data[b*kvS*kvH*D + kvs*kvH*D + kv_h*D + d]; |
| 75 | + dot += qv * kv; |
| 76 | + } |
| 77 | + scores[kvs] = dot * scale; |
| 78 | + } |
| 79 | + |
| 80 | + // Causal mask |
| 81 | + if (is_causal) { |
| 82 | + int64_t valid = std::min( |
| 83 | + start_pos + qs + 1, |
| 84 | + static_cast<int64_t>(num_valid_keys)); |
| 85 | + for (int64_t j = valid; j < num_valid_keys; j++) { |
| 86 | + scores[j] = -std::numeric_limits<float>::infinity(); |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + // Softmax |
| 91 | + float max_val = *std::max_element(scores.begin(), scores.end()); |
| 92 | + float sum = 0; |
| 93 | + for (auto& s : scores) { |
| 94 | + s = std::exp(s - max_val); |
| 95 | + sum += s; |
| 96 | + } |
| 97 | + if (sum > 0) { |
| 98 | + for (auto& s : scores) { |
| 99 | + s /= sum; |
| 100 | + } |
| 101 | + } |
| 102 | + |
| 103 | + // output = scores @ V |
| 104 | + for (int d = 0; d < D; d++) { |
| 105 | + float val = 0; |
| 106 | + for (int kvs = 0; kvs < num_valid_keys; kvs++) { |
| 107 | + float vv = v_data[b*kvS*kvH*D + kvs*kvH*D + kv_h*D + d]; |
| 108 | + val += scores[kvs] * vv; |
| 109 | + } |
| 110 | + out_data[b*qS*qH*D + qs*qH*D + h*D + d] = val; |
| 111 | + } |
| 112 | + } |
| 113 | + } |
| 114 | + } |
| 115 | +} |
| 116 | + |
| 117 | +} // namespace |
| 118 | + |
| 119 | +// With a single KV entry (start_pos=0), output must equal V[0]. |
| 120 | +TEST(OpCustomSdpaTest, DecodeSingleKV) { |
| 121 | + TensorFactory<executorch::aten::ScalarType::Float> tf; |
| 122 | + |
| 123 | + executorch::aten::Tensor q = tf.make( |
| 124 | + {1, 1, 2, 4}, |
| 125 | + {0.8823, 0.9150, 0.3829, 0.9593, |
| 126 | + 0.3904, 0.6009, 0.2566, 0.7936}); |
| 127 | + |
| 128 | + executorch::aten::Tensor k = tf.make( |
| 129 | + {1, 1, 2, 4}, |
| 130 | + {0.8854, 0.5739, 0.2666, 0.6274, |
| 131 | + 0.2696, 0.4414, 0.2969, 0.8317}); |
| 132 | + |
| 133 | + executorch::aten::Tensor v = tf.make( |
| 134 | + {1, 1, 2, 4}, |
| 135 | + {0.6343, 0.3644, 0.7104, 0.9464, |
| 136 | + 0.7890, 0.2814, 0.7886, 0.5895}); |
| 137 | + |
| 138 | + // softmax of a single score is always 1.0, so output == V |
| 139 | + executorch::aten::Tensor expected = tf.make( |
| 140 | + {1, 1, 2, 4}, |
| 141 | + {0.6343, 0.3644, 0.7104, 0.9464, |
| 142 | + 0.7890, 0.2814, 0.7886, 0.5895}); |
| 143 | + |
| 144 | + executorch::aten::Tensor out = tf.zeros({1, 1, 2, 4}); |
| 145 | + call_custom_sdpa(q, k, v, /*start_pos=*/0, {}, 0.0, false, {}, out); |
| 146 | + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-6, 1e-6); |
| 147 | +} |
| 148 | + |
| 149 | +// Decode with 3 valid KV entries, verified against reference computation. |
| 150 | +TEST(OpCustomSdpaTest, DecodeNonCausal) { |
| 151 | + TensorFactory<executorch::aten::ScalarType::Float> tf; |
| 152 | + |
| 153 | + // Q: [B=1, S=1, H=2, D=4] |
| 154 | + executorch::aten::Tensor q = tf.make( |
| 155 | + {1, 1, 2, 4}, |
| 156 | + {0.8823, 0.9150, 0.3829, 0.9593, |
| 157 | + 0.3904, 0.6009, 0.2566, 0.7936}); |
| 158 | + |
| 159 | + // K, V: [B=1, kv_len=4, H=2, D=4], first 3 entries valid |
| 160 | + executorch::aten::Tensor k = tf.make( |
| 161 | + {1, 4, 2, 4}, |
| 162 | + {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317, |
| 163 | + 0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753, |
| 164 | + 0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423, |
| 165 | + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000}); |
| 166 | + |
| 167 | + executorch::aten::Tensor v = tf.make( |
| 168 | + {1, 4, 2, 4}, |
| 169 | + {0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895, |
| 170 | + 0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071, |
| 171 | + 0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278, |
| 172 | + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000}); |
| 173 | + |
| 174 | + int64_t start_pos = 2; |
| 175 | + int num_valid = 3; |
| 176 | + |
| 177 | + std::vector<float> ref(8, 0.0f); |
| 178 | + compute_reference_sdpa( |
| 179 | + q.const_data_ptr<float>(), 1, 1, 2, 4, |
| 180 | + k.const_data_ptr<float>(), 4, 2, |
| 181 | + v.const_data_ptr<float>(), |
| 182 | + ref.data(), false, start_pos, num_valid); |
| 183 | + |
| 184 | + executorch::aten::Tensor expected = tf.make({1, 1, 2, 4}, ref); |
| 185 | + executorch::aten::Tensor out = tf.zeros({1, 1, 2, 4}); |
| 186 | + call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out); |
| 187 | + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-4, 1e-4); |
| 188 | +} |
| 189 | + |
| 190 | +// GQA: 4 query heads sharing 2 KV heads. |
| 191 | +TEST(OpCustomSdpaTest, DecodeGQA) { |
| 192 | + TensorFactory<executorch::aten::ScalarType::Float> tf; |
| 193 | + |
| 194 | + // Q: [B=1, S=1, H_q=4, D=4] |
| 195 | + executorch::aten::Tensor q = tf.make( |
| 196 | + {1, 1, 4, 4}, |
| 197 | + {0.8823, 0.9150, 0.3829, 0.9593, |
| 198 | + 0.3904, 0.6009, 0.2566, 0.7936, |
| 199 | + 0.9408, 0.1332, 0.9346, 0.5936, |
| 200 | + 0.8694, 0.5677, 0.7411, 0.4294}); |
| 201 | + |
| 202 | + // K: [B=1, kv_len=3, H_kv=2, D=4] |
| 203 | + executorch::aten::Tensor k = tf.make( |
| 204 | + {1, 3, 2, 4}, |
| 205 | + {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317, |
| 206 | + 0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753, |
| 207 | + 0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423}); |
| 208 | + |
| 209 | + // V: [B=1, kv_len=3, H_kv=2, D=4] |
| 210 | + executorch::aten::Tensor v = tf.make( |
| 211 | + {1, 3, 2, 4}, |
| 212 | + {0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895, |
| 213 | + 0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071, |
| 214 | + 0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278}); |
| 215 | + |
| 216 | + int64_t start_pos = 2; |
| 217 | + int num_valid = 3; |
| 218 | + |
| 219 | + std::vector<float> ref(16, 0.0f); |
| 220 | + compute_reference_sdpa( |
| 221 | + q.const_data_ptr<float>(), 1, 1, 4, 4, |
| 222 | + k.const_data_ptr<float>(), 3, 2, |
| 223 | + v.const_data_ptr<float>(), |
| 224 | + ref.data(), false, start_pos, num_valid); |
| 225 | + |
| 226 | + executorch::aten::Tensor expected = tf.make({1, 1, 4, 4}, ref); |
| 227 | + executorch::aten::Tensor out = tf.zeros({1, 1, 4, 4}); |
| 228 | + call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out); |
| 229 | + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-4, 1e-4); |
| 230 | +} |
| 231 | + |
| 232 | +// For seq_len=1, causal mask doesn't restrict any positions |
| 233 | +// (all start_pos+1 entries are visible), so result must match non-causal. |
| 234 | +TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) { |
| 235 | + TensorFactory<executorch::aten::ScalarType::Float> tf; |
| 236 | + |
| 237 | + executorch::aten::Tensor q = tf.make( |
| 238 | + {1, 1, 2, 4}, |
| 239 | + {0.8823, 0.9150, 0.3829, 0.9593, |
| 240 | + 0.3904, 0.6009, 0.2566, 0.7936}); |
| 241 | + |
| 242 | + executorch::aten::Tensor k = tf.make( |
| 243 | + {1, 4, 2, 4}, |
| 244 | + {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317, |
| 245 | + 0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753, |
| 246 | + 0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423, |
| 247 | + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000}); |
| 248 | + |
| 249 | + executorch::aten::Tensor v = tf.make( |
| 250 | + {1, 4, 2, 4}, |
| 251 | + {0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895, |
| 252 | + 0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071, |
| 253 | + 0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278, |
| 254 | + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000}); |
| 255 | + |
| 256 | + int64_t start_pos = 2; |
| 257 | + |
| 258 | + executorch::aten::Tensor out_nc = tf.zeros({1, 1, 2, 4}); |
| 259 | + call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out_nc); |
| 260 | + |
| 261 | + executorch::aten::Tensor out_c = tf.zeros({1, 1, 2, 4}); |
| 262 | + call_custom_sdpa(q, k, v, start_pos, {}, 0.0, true, {}, out_c); |
| 263 | + |
| 264 | + EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6); |
| 265 | +} |
0 commit comments