diff --git a/extension/llm/custom_ops/op_custom_sdpa_test.cpp b/extension/llm/custom_ops/op_custom_sdpa_test.cpp
new file mode 100644
index 00000000000..92472c7f64a
--- /dev/null
+++ b/extension/llm/custom_ops/op_custom_sdpa_test.cpp
@@ -0,0 +1,292 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// Tests for the unfused SDPA code path (cpu_sdpa) dispatched when
+// seq_len == 1 and inputs are non-quantized (the decode fast-path).
+// These call custom_sdpa_out directly, not through sdpa_with_kv_cache.
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <vector>
+
+#include <executorch/extension/llm/custom_ops/op_sdpa.h>
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+using executorch::runtime::testing::TensorFactory;
+
+namespace {
+
+// Helper: call custom_sdpa_out. Inputs use [B, S, H, D] layout.
+executorch::aten::Tensor call_custom_sdpa(
+    const executorch::aten::Tensor& q,
+    const executorch::aten::Tensor& k,
+    const executorch::aten::Tensor& v,
+    int64_t start_pos,
+    const std::optional<executorch::aten::Tensor>& attn_mask,
+    double dropout_p,
+    bool is_causal,
+    std::optional<double> scale,
+    executorch::aten::Tensor& out) {
+  executorch::runtime::KernelRuntimeContext ctx{};
+  return torch::executor::native::custom_sdpa_out(
+      ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, out);
+}
+
+/**
+ * Naive reference SDPA for [B, S, H, D] layout.
+ * Element [b,s,h,d] is at index b*S*H*D + s*H*D + h*D + d.
+ * Only first num_valid_keys KV entries are used.
+ */
+void compute_reference_sdpa(
+    const float* q_data,
+    int B,
+    int qS,
+    int qH,
+    int D,
+    const float* k_data,
+    int kvS,
+    int kvH,
+    const float* v_data,
+    float* out_data,
+    bool is_causal,
+    int64_t start_pos,
+    int num_valid_keys) {
+  float scale = 1.0f / std::sqrt(static_cast<float>(D));
+  int num_reps = qH / kvH;
+
+  for (int b = 0; b < B; b++) {
+    for (int h = 0; h < qH; h++) {
+      int kv_h = h / num_reps;
+      for (int qs = 0; qs < qS; qs++) {
+        // scores = Q @ K^T * scale
+        std::vector<float> scores(num_valid_keys);
+        for (int kvs = 0; kvs < num_valid_keys; kvs++) {
+          float dot = 0;
+          for (int d = 0; d < D; d++) {
+            float qv = q_data[b * qS * qH * D + qs * qH * D + h * D + d];
+            float kv = k_data[b * kvS * kvH * D + kvs * kvH * D + kv_h * D + d];
+            dot += qv * kv;
+          }
+          scores[kvs] = dot * scale;
+        }
+
+        // Causal mask
+        if (is_causal) {
+          int64_t valid = std::min(
+              start_pos + qs + 1, static_cast<int64_t>(num_valid_keys));
+          for (int64_t j = valid; j < num_valid_keys; j++) {
+            scores[j] = -std::numeric_limits<float>::infinity();
+          }
+        }
+
+        // Softmax
+        float max_val = *std::max_element(scores.begin(), scores.end());
+        float sum = 0;
+        for (auto& s : scores) {
+          s = std::exp(s - max_val);
+          sum += s;
+        }
+        if (sum > 0) {
+          for (auto& s : scores) {
+            s /= sum;
+          }
+        }
+
+        // output = scores @ V
+        for (int d = 0; d < D; d++) {
+          float val = 0;
+          for (int kvs = 0; kvs < num_valid_keys; kvs++) {
+            float vv = v_data[b * kvS * kvH * D + kvs * kvH * D + kv_h * D + d];
+            val += scores[kvs] * vv;
+          }
+          out_data[b * qS * qH * D + qs * qH * D + h * D + d] = val;
+        }
+      }
+    }
+  }
+}
+
+} // namespace
+
+// With a single KV entry (start_pos=0), output must equal V[0].
+TEST(OpCustomSdpaTest, DecodeSingleKV) {
+  TensorFactory<executorch::aten::ScalarType::Float> tf;
+
+  executorch::aten::Tensor q = tf.make(
+      {1, 1, 2, 4},
+      {0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936});
+
+  executorch::aten::Tensor k = tf.make(
+      {1, 1, 2, 4},
+      {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317});
+
+  executorch::aten::Tensor v = tf.make(
+      {1, 1, 2, 4},
+      {0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});
+
+  // softmax of a single score is always 1.0, so output == V
+  executorch::aten::Tensor expected = tf.make(
+      {1, 1, 2, 4},
+      {0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});
+
+  executorch::aten::Tensor out = tf.zeros({1, 1, 2, 4});
+  call_custom_sdpa(q, k, v, /*start_pos=*/0, {}, 0.0, false, {}, out);
+  EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-6, 1e-6);
+}
+
+// Decode with 3 valid KV entries, verified against reference computation.
+TEST(OpCustomSdpaTest, DecodeNonCausal) {
+  TensorFactory<executorch::aten::ScalarType::Float> tf;
+
+  // Q: [B=1, S=1, H=2, D=4]
+  executorch::aten::Tensor q = tf.make(
+      {1, 1, 2, 4},
+      {0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936});
+
+  // K, V: [B=1, kv_len=4, H=2, D=4], first 3 entries valid
+  executorch::aten::Tensor k = tf.make(
+      {1, 4, 2, 4},
+      {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
+       0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753,
+       0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
+       0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
+
+  executorch::aten::Tensor v = tf.make(
+      {1, 4, 2, 4},
+      {0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895,
+       0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071,
+       0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278,
+       0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
+
+  int64_t start_pos = 2;
+  int num_valid = 3;
+
+  std::vector<float> ref(8, 0.0f);
+  compute_reference_sdpa(
+      q.const_data_ptr<float>(),
+      1,
+      1,
+      2,
+      4,
+      k.const_data_ptr<float>(),
+      4,
+      2,
+      v.const_data_ptr<float>(),
+      ref.data(),
+      false,
+      start_pos,
+      num_valid);
+
+  executorch::aten::Tensor expected = tf.make({1, 1, 2, 4}, ref);
+  executorch::aten::Tensor out = tf.zeros({1, 1, 2, 4});
+  call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out);
+  EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-4, 1e-4);
+}
+
+// GQA: 4 query heads sharing 2 KV heads.
+TEST(OpCustomSdpaTest, DecodeGQA) {
+  TensorFactory<executorch::aten::ScalarType::Float> tf;
+
+  // Q: [B=1, S=1, H_q=4, D=4]
+  executorch::aten::Tensor q = tf.make(
+      {1, 1, 4, 4},
+      {0.8823,
+       0.9150,
+       0.3829,
+       0.9593,
+       0.3904,
+       0.6009,
+       0.2566,
+       0.7936,
+       0.9408,
+       0.1332,
+       0.9346,
+       0.5936,
+       0.8694,
+       0.5677,
+       0.7411,
+       0.4294});
+
+  // K: [B=1, kv_len=3, H_kv=2, D=4]
+  executorch::aten::Tensor k =
+      tf.make({1, 3, 2, 4}, {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414,
+                             0.2969, 0.8317, 0.1053, 0.2695, 0.3588, 0.1994,
+                             0.5472, 0.0062, 0.9516, 0.0753, 0.8860, 0.5832,
+                             0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423});
+
+  // V: [B=1, kv_len=3, H_kv=2, D=4]
+  executorch::aten::Tensor v =
+      tf.make({1, 3, 2, 4}, {0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814,
+                             0.7886, 0.5895, 0.7539, 0.1952, 0.0050, 0.3068,
+                             0.1165, 0.9103, 0.6440, 0.7071, 0.6581, 0.4913,
+                             0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278});
+
+  int64_t start_pos = 2;
+  int num_valid = 3;
+
+  std::vector<float> ref(16, 0.0f);
+  compute_reference_sdpa(
+      q.const_data_ptr<float>(),
+      1,
+      1,
+      4,
+      4,
+      k.const_data_ptr<float>(),
+      3,
+      2,
+      v.const_data_ptr<float>(),
+      ref.data(),
+      false,
+      start_pos,
+      num_valid);
+
+  executorch::aten::Tensor expected = tf.make({1, 1, 4, 4}, ref);
+  executorch::aten::Tensor out = tf.zeros({1, 1, 4, 4});
+  call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out);
+  EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-4, 1e-4);
+}
+
+// For seq_len=1, causal mask doesn't restrict any positions
+// (all start_pos+1 entries are visible), so result must match non-causal.
+TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) {
+  TensorFactory<executorch::aten::ScalarType::Float> tf;
+
+  executorch::aten::Tensor q = tf.make(
+      {1, 1, 2, 4},
+      {0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936});
+
+  executorch::aten::Tensor k = tf.make(
+      {1, 4, 2, 4},
+      {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
+       0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753,
+       0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
+       0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
+
+  executorch::aten::Tensor v = tf.make(
+      {1, 4, 2, 4},
+      {0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895,
+       0.7539, 0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071,
+       0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542, 0.3278,
+       0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000});
+
+  int64_t start_pos = 2;
+
+  executorch::aten::Tensor out_nc = tf.zeros({1, 1, 2, 4});
+  call_custom_sdpa(q, k, v, start_pos, {}, 0.0, false, {}, out_nc);
+
+  executorch::aten::Tensor out_c = tf.zeros({1, 1, 2, 4});
+  call_custom_sdpa(q, k, v, start_pos, {}, 0.0, true, {}, out_c);
+
+  EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6);
+}
diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 72bddce7b5b..17759fa6dd5 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -412,75 +412,88 @@ Tensor& custom_sdpa_out_impl(
       InvalidArgument,
       output);
 
-  // TODO(task): replace the template param selection logic
-  // with whatever apprpriately makes more sense for
-  ET_SWITCH_FLOAT_TYPES(
-      output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-        // TODO we need to re-evaluate this for ARM CPUs
-        // And there can be many so instead of templatizing
-        // we might consider another appraoch
-        if (seq_len >= 768) {
-          sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
-              ctx,
-              output,
-              q,
-              k,
-              v,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              q_zero_points, // q_zero_points
-              q_scales, // q_scales
-              k_zero_points, // k_zero_points
-              k_scales, // k_scales
-              v_zero_points, // v_zero_points
-              v_scales, // v_scales
-              seq_dim, /* seq_dim */
-              start_pos,
-              num_keys_for_causal_attention);
-        } else if (seq_len >= 192) {
-          sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
-              ctx,
-              output,
-              q,
-              k,
-              v,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              q_zero_points, // q_zero_points
-              q_scales, // q_scales
-              k_zero_points, // k_zero_points
-              k_scales, // k_scales
-              v_zero_points, // v_zero_points
-              v_scales, // v_scales
-              seq_dim, /* seq_dim */
-              start_pos,
-              num_keys_for_causal_attention);
-        } else {
-          sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
-              ctx,
-              output,
-              q,
-              k,
-              v,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              q_zero_points, // q_zero_points
-              q_scales, // q_scales
-              k_zero_points, // k_zero_points
-              k_scales, // k_scales
-              v_zero_points, // v_zero_points
-              v_scales, // v_scales
-              seq_dim, /* seq_dim */
-              start_pos,
-              num_keys_for_causal_attention);
-        }
-      });
+  bool use_unfused_sdpa = q.scalar_type() != ScalarType::Char && seq_len == 1;
+  if (use_unfused_sdpa) {
+    ET_SWITCH_FLOAT_TYPES(output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
+      sdpa::impl::cpu_sdpa<CTYPE>(
+          ctx,
+          output,
+          q,
+          k,
+          v,
+          is_causal,
+          attn_mask,
+          scale,
+          seq_dim,
+          start_pos,
+          num_keys_for_causal_attention);
+    });
+  } else {
+    ET_SWITCH_FLOAT_TYPES(
+        output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+          if (seq_len >= 768) {
+            sdpa::impl::cpu_flash_attention<CTYPE, 256, 512>(
+                ctx,
+                output,
+                q,
+                k,
+                v,
+                dropout_p,
+                is_causal,
+                attn_mask,
+                scale,
+                q_zero_points,
+                q_scales,
+                k_zero_points,
+                k_scales,
+                v_zero_points,
+                v_scales,
+                seq_dim,
+                start_pos,
+                num_keys_for_causal_attention);
+          } else if (seq_len >= 192) {
+            sdpa::impl::cpu_flash_attention<CTYPE, 64, 512>(
+                ctx,
+                output,
+                q,
+                k,
+                v,
+                dropout_p,
+                is_causal,
+                attn_mask,
+                scale,
+                q_zero_points,
+                q_scales,
+                k_zero_points,
+                k_scales,
+                v_zero_points,
+                v_scales,
+                seq_dim,
+                start_pos,
+                num_keys_for_causal_attention);
+          } else {
+            sdpa::impl::cpu_flash_attention<CTYPE, 32, 512>(
+                ctx,
+                output,
+                q,
+                k,
+                v,
+                dropout_p,
+                is_causal,
+                attn_mask,
+                scale,
+                q_zero_points,
+                q_scales,
+                k_zero_points,
+                k_scales,
+                v_zero_points,
+                v_scales,
+                seq_dim,
+                start_pos,
+                num_keys_for_causal_attention);
+          }
+        });
+  }
 
   return output;
 }
diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index 73c5ccf707f..467af1c89f4 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -1097,6 +1097,211 @@ void cpu_flash_attention(
   torch::executor::parallel_for(
       0, batchSize * num_head * qSlice, 1, compute_lambda);
 }
+
+/**
+ * @brief Non-flash (unfused) SDPA implementation using standard GEMM.
+ *
+ * Single full GEMM per head for Q@K^T and scores@V, with standard 3-pass
+ * softmax (no tiling). Useful as a simpler baseline and for cases where
+ * flash attention is not optimal (e.g. very short sequences).
+ *
+ * @tparam scalar_t Data type for computation
+ * @param seq_dim Which dimension is sequence dimension (SeqDim::ONE or TWO)
+ *        Used for all of Q, K, V, and output stride extraction.
+ * @param start_pos Starting position for causal masking
+ * @param num_keys_for_causal_attention Number of keys for causal attention
+ */
+template <typename scalar_t>
+void cpu_sdpa(
+    RuntimeContext& ctx,
+    Tensor& output,
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    bool is_causal,
+    const optional<Tensor>& attn_mask,
+    const optional<double>& scale,
+    const SeqDim seq_dim,
+    const int64_t start_pos,
+    const int64_t num_keys_for_causal_attention) {
+  using accum_t = scalar_t;
+  using Vec = vec::Vectorized<accum_t>;
+  accum_t scaling_factor =
+      static_cast<accum_t>(calculate_scale(query, scale));
+
+  int64_t batchSize = query.size(0);
+  int64_t num_head = query.size(1);
+  int64_t qSize = query.size(2);
+  int64_t headSize = query.size(3);
+  int64_t kvSize = value.size(2);
+  int64_t num_heads_kv = key.size(1);
+
+  if (seq_dim == SeqDim::ONE) {
+    num_head = query.size(2);
+    num_heads_kv = key.size(2);
+    qSize = query.size(1);
+    kvSize = value.size(1);
+  }
+
+  if (num_keys_for_causal_attention > 0) {
+    ET_CHECK_MSG(
+        num_keys_for_causal_attention <= kvSize,
+        "num_keys_for_causal_attention must be <= kvSize");
+    kvSize = num_keys_for_causal_attention;
+  }
+
+  ET_CHECK_MSG(
+      num_heads_kv <= num_head,
+      "cpu_sdpa does not support num kv heads > num query heads");
+  ET_CHECK_MSG(
+      num_head % num_heads_kv == 0,
+      "cpu_sdpa: num query heads must be divisible by num kv heads");
+  int64_t num_reps = num_head / num_heads_kv;
+
+  bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
+
+  // Extract strides, swapping seq/head dims based on seq_dim
+  auto q_strides = query.strides();
+  int64_t qStrideB = q_strides[0];
+  int64_t qStrideH = (seq_dim == SeqDim::ONE) ? q_strides[2] : q_strides[1];
+  int64_t qStrideM = (seq_dim == SeqDim::ONE) ? q_strides[1] : q_strides[2];
+
+  auto k_strides = key.strides();
+  int64_t kStrideB = k_strides[0];
+  int64_t kStrideH = (seq_dim == SeqDim::ONE) ? k_strides[2] : k_strides[1];
+  int64_t kStrideN = (seq_dim == SeqDim::ONE) ? k_strides[1] : k_strides[2];
+
+  auto v_strides = value.strides();
+  int64_t vStrideB = v_strides[0];
+  int64_t vStrideH = (seq_dim == SeqDim::ONE) ? v_strides[2] : v_strides[1];
+  int64_t vStrideN = (seq_dim == SeqDim::ONE) ? v_strides[1] : v_strides[2];
+
+  auto o_strides = output.strides();
+  int64_t oStrideB = o_strides[0];
+  int64_t oStrideH = (seq_dim == SeqDim::ONE) ? o_strides[2] : o_strides[1];
+  int64_t oStrideM = (seq_dim == SeqDim::ONE) ? o_strides[1] : o_strides[2];
+
+  int64_t mStrideM = 0;
+  if (has_attn_mask) {
+    auto m_strides = attn_mask.value().strides();
+    mStrideM = m_strides[0];
+  }
+
+  // Allocate per-thread scores buffer: [qSize, kvSize] per (batch, head)
+#ifdef ET_USE_THREADPOOL
+  int64_t num_thread =
+      ::executorch::extension::threadpool::get_threadpool()->get_thread_count();
+#else
+  int64_t num_thread = 1;
+#endif
+
+  int64_t scores_per_thread = qSize * kvSize;
+  int64_t size_bytes = scores_per_thread * num_thread * sizeof(accum_t);
+  std::unique_ptr<char[]> allocated_buf;
+  void* buf;
+  Result<void*> scratch = ctx.allocate_temp(size_bytes, 64);
+  if (!scratch.ok()) {
+    allocated_buf = std::make_unique<char[]>(size_bytes);
+    buf = allocated_buf.get();
+  } else {
+    buf = scratch.get();
+  }
+  accum_t* buf_data = reinterpret_cast<accum_t*>(buf);
+
+  const scalar_t* q_data = query.const_data_ptr<scalar_t>();
+  const scalar_t* k_data = key.const_data_ptr<scalar_t>();
+  const scalar_t* v_data = value.const_data_ptr<scalar_t>();
+  const accum_t* mask_data =
+      has_attn_mask ? attn_mask.value().const_data_ptr<accum_t>() : nullptr;
+  scalar_t* out_data = output.mutable_data_ptr<scalar_t>();
+
+  auto compute_lambda = [&](int64_t begin, int64_t end) {
+    int64_t ompIdx = torch::executor::get_thread_num();
+    accum_t* scores = buf_data + ompIdx * scores_per_thread;
+
+    for (int64_t idx = begin; idx < end; ++idx) {
+      int64_t b = idx / num_head;
+      int64_t h = idx % num_head;
+      int64_t kv_h = h / num_reps;
+
+      // Pointer to Q[b, h, :, :] and K[b, kv_h, :, :] with appropriate strides
+      const scalar_t* q_ptr = q_data + b * qStrideB + h * qStrideH;
+      const scalar_t* k_ptr = k_data + b * kStrideB + kv_h * kStrideH;
+      const scalar_t* v_ptr = v_data + b * vStrideB + kv_h * vStrideH;
+      scalar_t* o_ptr = out_data + b * oStrideB + h * oStrideH;
+
+      // GEMM 1: scores[qSize, kvSize] = scaling_factor * Q[qSize, D] @ K^T[D,
+      // kvSize]
+      ::executorch::cpublas::gemm(
+          ::executorch::cpublas::TransposeType::Transpose,
+          ::executorch::cpublas::TransposeType::NoTranspose,
+          kvSize,
+          qSize,
+          headSize,
+          scaling_factor,
+          k_ptr,
+          kStrideN,
+          q_ptr,
+          qStrideM,
+          static_cast<accum_t>(0),
+          scores,
+          kvSize);
+
+      // Causal mask + attention mask + softmax per query row
+      for (int64_t qi = 0; qi < qSize; ++qi) {
+        accum_t* row = scores + qi * kvSize;
+
+        // Apply attention mask if present
+        if (has_attn_mask) {
+          const accum_t* mask_row = mask_data + qi * mStrideM;
+          for (int64_t j = 0; j < kvSize; ++j) {
+            row[j] += mask_row[j];
+          }
+        }
+
+        // Apply causal mask
+        if (is_causal) {
+          int64_t valid = std::min(start_pos + qi + 1, kvSize);
+          for (int64_t j = valid; j < kvSize; ++j) {
+            row[j] = -std::numeric_limits<accum_t>::infinity();
+          }
+        }
+
+        // Softmax: find max, compute exp, normalize
+        accum_t max_val = vec::reduce_all<accum_t>(
+            [](Vec& x, Vec& y) { return vec::maximum(x, y); }, row, kvSize);
+
+        if (max_val == -std::numeric_limits<accum_t>::infinity()) {
+          fill_stub(row, static_cast<accum_t>(0), kvSize);
+        } else {
+          accum_t sum_val = max_val;
+          const int kvSizeInt = static_cast<int>(kvSize);
+          _exp_reduce_sum_fusion_kernel(row, kvSizeInt, row, sum_val);
+          accum_t inv_sum = static_cast<accum_t>(1) / sum_val;
+          vec::map<accum_t>(
+              [inv_sum](Vec x) { return x * Vec(inv_sum); }, row, row, kvSize);
+        }
+      }
+
+      // GEMM 2: output[qSize, D] = scores[qSize, kvSize] @ V[kvSize, D]
+      ::executorch::cpublas::gemm(
+          ::executorch::cpublas::TransposeType::NoTranspose,
+          ::executorch::cpublas::TransposeType::NoTranspose,
+          headSize,
+          qSize,
+          kvSize,
+          static_cast<accum_t>(1),
+          v_ptr,
+          vStrideN,
+          scores,
+          kvSize,
+          static_cast<accum_t>(0),
+          o_ptr,
+          oStrideM);
+    }
+  };
+  torch::executor::parallel_for(0, batchSize * num_head, 1, compute_lambda);
+}
+
 } // namespace sdpa::impl
 } // namespace native
 } // namespace executor
diff --git a/extension/llm/custom_ops/targets.bzl b/extension/llm/custom_ops/targets.bzl
index 6746d7ab877..61297e65dae 100644
--- a/extension/llm/custom_ops/targets.bzl
+++ b/extension/llm/custom_ops/targets.bzl
@@ -125,6 +125,20 @@ def define_common_targets():
         ],
     )
 
+    runtime.cxx_test(
+        name = "op_custom_sdpa_test",
+        srcs = [
+            "op_custom_sdpa_test.cpp",
+        ],
+        visibility = ["//executorch/..."],
+        deps = [
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+            "//executorch/kernels/test:test_util",
+            ":custom_ops",
+        ],
+    )
+
     ## For preprocess
     runtime.python_library(
         name = "preprocess_custom_ops_py",