27 changes: 19 additions & 8 deletions extension/llm/custom_ops/op_sdpa.cpp
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
@@ -420,15 +420,26 @@
if (use_unfused_sdpa) {
ET_SWITCH_FLOAT_TYPES(
output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
sdpa::impl::cpu_sdpa<CTYPE>(
ctx, output, q, k, v, is_causal, attn_mask, scale,
seq_dim,
start_pos, num_keys_for_causal_attention,
q_zero_points, q_scales,
k_zero_points, k_scales,
v_zero_points, v_scales);
});
sdpa::impl::cpu_sdpa<CTYPE>(
ctx,
output,
q,
k,
v,
is_causal,
attn_mask,
scale,
q_seq_dim,
k_seq_dim,
v_seq_dim,
start_pos,
num_keys_for_causal_attention,
q_zero_points, q_scales,
k_zero_points, k_scales,
v_zero_points, v_scales);
});
} else {
// Flash attention path (default) with tile-size selection
ET_SWITCH_FLOAT_TYPES(
output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
if (seq_len >= 768) {
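With this change the caller passes a separate SeqDim for each of Q, K, and V rather than a single shared seq_dim. Below is a minimal sketch of a call that mixes layouts, assuming the [B, S, H, D] (SeqDim::ONE) and [B, H, S, D] (SeqDim::TWO) conventions described in op_sdpa_impl.h; the tensor names (q, k_cache, v_cache) and the mixed-layout scenario are illustrative, not taken from this PR.

// Hypothetical decode-time call, not part of this PR: Q as freshly projected
// activations in [B, S_q, H, D] (SeqDim::ONE) while the K/V cache is laid out
// as [B, H_kv, S_kv, D] (SeqDim::TWO). Per-tensor seq dims let the kernel
// read both without transposing the cache first.
sdpa::impl::cpu_sdpa<float>(
    ctx,
    output,
    q,        // [B, S_q, H, D]
    k_cache,  // [B, H_kv, S_kv, D]
    v_cache,  // [B, H_kv, S_kv, D]
    /*is_causal=*/true,
    attn_mask,
    scale,
    SeqDim::ONE,  // q_seq_dim
    SeqDim::TWO,  // k_seq_dim
    SeqDim::TWO,  // v_seq_dim
    start_pos,
    /*num_keys_for_causal_attention=*/-1);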
138 changes: 77 additions & 61 deletions extension/llm/custom_ops/op_sdpa_impl.h
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
@@ -1068,17 +1068,23 @@
}

/**
* @brief Non-flash (unfused) SDPA implementation using standard GEMM.
* @brief Non-flash (unfused) SDPA: full Q@K^T, softmax, then scores@V.
*
* Single full GEMM per head for Q@K^T and scores@V, with standard 3-pass
* softmax (no tiling). Useful as a simpler baseline and for cases where
* flash attention is not optimal (e.g. very short sequences).
* Avoids within-head tiling overhead of flash attention, which hurts when Q
* has very few rows (e.g. decode with SeqLen=1). Float-only; no quantized
* input support.
*
* @tparam scalar_t Data type for computation
* @param seq_dim Which dimension is sequence dimension (SeqDim::ONE or TWO)
* Used for all of Q, K, V, and output stride extraction.
* @param start_pos Starting position for causal masking
* @param num_keys_for_causal_attention Number of keys for causal attention
* @tparam scalar_t The data type for computation (float or double)
* @param output Output tensor [B, H, S_q, D] (or transposed SeqDim layout)
* @param query Query tensor [B, H, S_q, D]
* @param key Key tensor [B, H_kv, S_kv, D]
* @param value Value tensor [B, H_kv, S_kv, D]
* @param is_causal Whether to apply causal (lower-triangular) masking
* @param attn_mask Optional 2-D float attention mask [S_q, S_kv]
* @param scale Optional scaling factor (default 1/sqrt(D))
* @param q_seq_dim, k_seq_dim, v_seq_dim Sequence-dimension layout of Q, K,
*        and V respectively (SeqDim::ONE or SeqDim::TWO)
* @param start_pos Starting position for causal masking during generation
* @param num_keys_for_causal_attention Number of keys to attend to (-1=all)
*/
template <typename scalar_t>
void cpu_sdpa(
@@ -1090,7 +1096,9 @@
bool is_causal,
const optional<Tensor>& attn_mask,
const optional<double>& scale,
const SeqDim seq_dim,
const SeqDim q_seq_dim,
const SeqDim k_seq_dim,
const SeqDim v_seq_dim,
const int64_t start_pos,
const int64_t num_keys_for_causal_attention,
const optional<Tensor>& q_zero_points = nullopt,
@@ -1099,23 +1107,25 @@
const optional<Tensor>& k_scales = nullopt,
const optional<Tensor>& v_zero_points = nullopt,
const optional<Tensor>& v_scales = nullopt) {
ET_CHECK_MSG(
query.scalar_type() != ScalarType::Char,
"Non-flash SDPA does not support quantized (int8) inputs");

using accum_t = scalar_t;
using Vec = vec::Vectorized<accum_t>;
accum_t scaling_factor = static_cast<accum_t>(calculate_scale(query, scale));

// Dimension indices: SeqDim::TWO => [B,H,S,D], SeqDim::ONE => [B,S,H,D]
int64_t q_head_idx = 3 - static_cast<int64_t>(q_seq_dim);
int64_t k_head_idx = 3 - static_cast<int64_t>(k_seq_dim);
int64_t v_head_idx = 3 - static_cast<int64_t>(v_seq_dim);

int64_t batchSize = query.size(0);
int64_t num_head = query.size(1);
int64_t qSize = query.size(2);
int64_t num_head = query.size(q_head_idx);
int64_t qSize = query.size(static_cast<int64_t>(q_seq_dim));
int64_t headSize = query.size(3);
int64_t kvSize = value.size(2);
int64_t num_heads_kv = key.size(1);

if (seq_dim == SeqDim::ONE) {
num_head = query.size(2);
num_heads_kv = key.size(2);
qSize = query.size(1);
kvSize = value.size(1);
}
int64_t kvSize = key.size(static_cast<int64_t>(k_seq_dim));
int64_t num_heads_kv = key.size(k_head_idx);

if (num_keys_for_causal_attention > 0) {
ET_CHECK_MSG(
@@ -1126,40 +1136,43 @@

ET_CHECK_MSG(
num_heads_kv <= num_head,
"cpu_sdpa does not support num kv heads > num query heads");
"cpu_sdpa: num kv heads > num query heads not supported");
ET_CHECK_MSG(
num_head % num_heads_kv == 0,
"cpu_sdpa: num query heads must be divisible by num kv heads");
int64_t num_reps = num_head / num_heads_kv;

bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();
bool is_quantized_sdpa = query.scalar_type() == ScalarType::Char;
if (has_attn_mask) {
ET_CHECK_MSG(attn_mask.value().dim() == 2, "attn_mask must be 2D");
}

// Extract strides, swapping seq/head dims based on seq_dim
auto q_strides = query.strides();
int64_t qStrideB = q_strides[0];
int64_t qStrideH = (seq_dim == SeqDim::ONE) ? q_strides[2] : q_strides[1];
int64_t qStrideM = (seq_dim == SeqDim::ONE) ? q_strides[1] : q_strides[2];
// Extract strides (same pattern as cpu_flash_attention)
auto strides = query.strides();
int64_t qStrideB = strides[0];
int64_t qStrideH = strides[q_head_idx];
int64_t qStrideM = strides[static_cast<int64_t>(q_seq_dim)];

auto k_strides = key.strides();
int64_t kStrideB = k_strides[0];
int64_t kStrideH = (seq_dim == SeqDim::ONE) ? k_strides[2] : k_strides[1];
int64_t kStrideN = (seq_dim == SeqDim::ONE) ? k_strides[1] : k_strides[2];
strides = key.strides();
int64_t kStrideB = strides[0];
int64_t kStrideH = strides[k_head_idx];
int64_t kStrideN = strides[static_cast<int64_t>(k_seq_dim)];

auto v_strides = value.strides();
int64_t vStrideB = v_strides[0];
int64_t vStrideH = (seq_dim == SeqDim::ONE) ? v_strides[2] : v_strides[1];
int64_t vStrideN = (seq_dim == SeqDim::ONE) ? v_strides[1] : v_strides[2];
strides = value.strides();
int64_t vStrideB = strides[0];
int64_t vStrideH = strides[v_head_idx];
int64_t vStrideN = strides[static_cast<int64_t>(v_seq_dim)];

auto o_strides = output.strides();
int64_t oStrideB = o_strides[0];
int64_t oStrideH = (seq_dim == SeqDim::ONE) ? o_strides[2] : o_strides[1];
int64_t oStrideM = (seq_dim == SeqDim::ONE) ? o_strides[1] : o_strides[2];
strides = output.strides();
int64_t oStrideB = strides[0];
int64_t oStrideH = strides[q_head_idx];
int64_t oStrideM = strides[static_cast<int64_t>(q_seq_dim)];

int64_t mStrideM = 0;
if (has_attn_mask) {
auto m_strides = attn_mask.value().strides();
mStrideM = m_strides[0];
strides = attn_mask.value().strides();
mStrideM = strides[0];
}

int64_t q_quant_params_StrideB = 0;
@@ -1189,26 +1202,26 @@
v_quant_params_StrideN = v_qp_strides[static_cast<int64_t>(v_seq_dim)];
}

// Allocate per-thread scores buffer: [qSize, kvSize] per (batch, head)
// Thread count for per-thread scratch allocation
#ifdef ET_USE_THREADPOOL
int64_t num_thread =
::executorch::extension::threadpool::get_threadpool()->get_thread_count();
#else
int64_t num_thread = 1;
#endif

int64_t scores_per_thread = qSize * kvSize;
int64_t size_bytes = scores_per_thread * num_thread * sizeof(accum_t);
// Allocate scores buffer: one [qSize x kvSize] matrix per thread
int64_t size_per_thread = qSize * kvSize;
int64_t size_bytes = size_per_thread * num_thread * sizeof(accum_t);
std::unique_ptr<char[]> allocated_buf;
void* buf;
accum_t* scores_buf;
Result<void*> scratch = ctx.allocate_temp(size_bytes, 64);
if (!scratch.ok()) {
allocated_buf = std::make_unique<char[]>(size_bytes);
buf = allocated_buf.get();
scores_buf = reinterpret_cast<accum_t*>(allocated_buf.get());
} else {
buf = scratch.get();
scores_buf = reinterpret_cast<accum_t*>(scratch.get());
}
accum_t* buf_data = reinterpret_cast<accum_t*>(buf);

// Allocate dequantization buffer for V (used by _qk_at_v_gemm when m > 4)
int64_t size_per_thread_qdq_vec = kvSize * headSize;
@@ -1228,24 +1241,26 @@
}
}

// Data pointers
const scalar_t* q_data = query.const_data_ptr<scalar_t>();
const scalar_t* k_data = key.const_data_ptr<scalar_t>();
const scalar_t* v_data = value.const_data_ptr<scalar_t>();
const accum_t* mask_data =
has_attn_mask ? attn_mask.value().const_data_ptr<accum_t>() : nullptr;
scalar_t* out_data = output.mutable_data_ptr<scalar_t>();

// One work-unit per (batch, head) — simpler than flash's (batch, head, block)
auto compute_lambda = [&](int64_t begin, int64_t end) {
int64_t ompIdx = torch::executor::get_thread_num();
accum_t* scores = buf_data + ompIdx * scores_per_thread;
accum_t* scores = scores_buf + ompIdx * size_per_thread;
accum_t* buf_qdq_ptr = is_quantized_sdpa
? scratch_for_quant_dequant + ompIdx * size_per_thread_qdq_vec
: nullptr;

for (int64_t idx = begin; idx < end; ++idx) {
int64_t b = idx / num_head;
int64_t h = idx % num_head;
int64_t kv_h = h / num_reps;
for (int64_t z = begin; z < end; z++) {
int64_t i = z / num_head; // batch index
int64_t j = z % num_head; // head index
int64_t j_kv = j / num_reps; // GQA: map query head to kv head

const void* q_ptr;
const void* k_ptr;
@@ -1257,21 +1272,21 @@
const int8_t* k_zp_ptr = nullptr;
const int8_t* v_zp_ptr = nullptr;

int64_t q_offset = b * qStrideB + h * qStrideH;
int64_t k_offset = b * kStrideB + kv_h * kStrideH;
int64_t v_offset = b * vStrideB + kv_h * vStrideH;
int64_t q_offset = i * qStrideB + j * qStrideH;
int64_t k_offset = i * kStrideB + j_kv * kStrideH;
int64_t v_offset = i * vStrideB + j_kv * vStrideH;

if (is_quantized_sdpa) {
q_ptr = reinterpret_cast<const int8_t*>(q_data) + q_offset;
k_ptr = reinterpret_cast<const int8_t*>(k_data) + k_offset;
v_ptr = reinterpret_cast<const int8_t*>(v_data) + v_offset;

int64_t q_qp_offset =
b * q_quant_params_StrideB + h * q_quant_params_StrideH;
i * q_quant_params_StrideB + j * q_quant_params_StrideH;
int64_t k_qp_offset =
b * k_quant_params_StrideB + kv_h * k_quant_params_StrideH;
i * k_quant_params_StrideB + j_kv * k_quant_params_StrideH;
int64_t v_qp_offset =
b * v_quant_params_StrideB + kv_h * v_quant_params_StrideH;
i * v_quant_params_StrideB + j_kv * v_quant_params_StrideH;

q_scales_ptr =
q_scales.value().const_data_ptr<float>() + q_qp_offset;
@@ -1290,7 +1305,7 @@
k_ptr = k_data + k_offset;
v_ptr = v_data + v_offset;
}
scalar_t* o_ptr = out_data + b * oStrideB + h * oStrideH;
scalar_t* out_ptr = out_data + i * oStrideB + j * oStrideH;

// GEMM 1: scores[qSize, kvSize] = Q[qSize, D] @ K^T[D, kvSize]
MaybeQuantizedMatrixData q_matrix(
@@ -1354,10 +1369,11 @@
qSize, headSize, kvSize,
scores, kvSize,
v_matrix, vStrideN,
o_ptr, oStrideM,
out_ptr, oStrideM,
static_cast<accum_t>(0), buf_qdq_ptr);
}
};

torch::executor::parallel_for(
0, batchSize * num_head, 1, compute_lambda);
}
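For comparison with the flash-attention path, the sketch below restates what cpu_sdpa computes for each (batch, head) work unit: a full Q@K^T GEMM, causal masking relative to start_pos, a standard max-subtracted softmax, and a scores@V GEMM, with query heads mapped onto KV heads for GQA. It assumes contiguous [B, H, S, D] float buffers and omits the stride handling, num_keys_for_causal_attention, quantization hooks, and threading of the real implementation; it is a readability aid, not code from this PR.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Reference-only unfused SDPA over contiguous [B, H, S, D] float buffers.
// scale is typically 1/sqrt(head_dim) when the caller does not override it.
void cpu_sdpa_reference(
    float* out,      // [batch, num_head, q_size, head_dim]
    const float* q,  // [batch, num_head, q_size, head_dim]
    const float* k,  // [batch, num_heads_kv, kv_size, head_dim]
    const float* v,  // [batch, num_heads_kv, kv_size, head_dim]
    int64_t batch, int64_t num_head, int64_t num_heads_kv,
    int64_t q_size, int64_t kv_size, int64_t head_dim,
    bool is_causal, int64_t start_pos, float scale) {
  const int64_t num_reps = num_head / num_heads_kv;  // GQA replication factor
  std::vector<float> scores(q_size * kv_size);       // per-(batch, head) scratch
  for (int64_t b = 0; b < batch; ++b) {
    for (int64_t h = 0; h < num_head; ++h) {
      const int64_t h_kv = h / num_reps;  // map query head onto its KV head
      const float* qh = q + (b * num_head + h) * q_size * head_dim;
      const float* kh = k + (b * num_heads_kv + h_kv) * kv_size * head_dim;
      const float* vh = v + (b * num_heads_kv + h_kv) * kv_size * head_dim;
      float* oh = out + (b * num_head + h) * q_size * head_dim;
      // GEMM 1: scores[m][n] = scale * dot(Q[m], K[n])
      for (int64_t m = 0; m < q_size; ++m) {
        for (int64_t n = 0; n < kv_size; ++n) {
          float acc = 0.0f;
          for (int64_t d = 0; d < head_dim; ++d) {
            acc += qh[m * head_dim + d] * kh[n * head_dim + d];
          }
          scores[m * kv_size + n] = acc * scale;
        }
      }
      // Causal mask: query row m lives at absolute position start_pos + m and
      // may only attend to keys at positions <= start_pos + m.
      if (is_causal) {
        for (int64_t m = 0; m < q_size; ++m) {
          for (int64_t n = start_pos + m + 1; n < kv_size; ++n) {
            scores[m * kv_size + n] = -std::numeric_limits<float>::infinity();
          }
        }
      }
      // Row-wise softmax with max subtraction for numerical stability.
      for (int64_t m = 0; m < q_size; ++m) {
        float* row = scores.data() + m * kv_size;
        const float mx = *std::max_element(row, row + kv_size);
        float sum = 0.0f;
        for (int64_t n = 0; n < kv_size; ++n) {
          row[n] = std::exp(row[n] - mx);
          sum += row[n];
        }
        for (int64_t n = 0; n < kv_size; ++n) {
          row[n] /= sum;
        }
      }
      // GEMM 2: out[m] = scores[m] @ V
      for (int64_t m = 0; m < q_size; ++m) {
        for (int64_t d = 0; d < head_dim; ++d) {
          float acc = 0.0f;
          for (int64_t n = 0; n < kv_size; ++n) {
            acc += scores[m * kv_size + n] * vh[n * head_dim + d];
          }
          oh[m * head_dim + d] = acc;
        }
      }
    }
  }
}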