diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index 7f8f760ed6b..62f594a8e6b 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -420,15 +420,26 @@ Tensor& custom_sdpa_out_impl(
   if (use_unfused_sdpa) {
     ET_SWITCH_FLOAT_TYPES(
         output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
-          sdpa::impl::cpu_sdpa<CTYPE>(
-              ctx, output, q, k, v, is_causal, attn_mask, scale,
-              seq_dim,
-              start_pos, num_keys_for_causal_attention,
-              q_zero_points, q_scales,
-              k_zero_points, k_scales,
-              v_zero_points, v_scales);
-        });
+          sdpa::impl::cpu_sdpa<CTYPE>(
+              ctx,
+              output,
+              q,
+              k,
+              v,
+              is_causal,
+              attn_mask,
+              scale,
+              q_seq_dim,
+              k_seq_dim,
+              v_seq_dim,
+              start_pos,
+              num_keys_for_causal_attention,
+              q_zero_points, q_scales,
+              k_zero_points, k_scales,
+              v_zero_points, v_scales);
+        });
   } else {
+    // Flash attention path (default) with tile-size selection
     ET_SWITCH_FLOAT_TYPES(
         output.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
           if (seq_len >= 768) {
diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index 47ba3632d26..50069ef3888 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -1068,17 +1068,23 @@ void cpu_flash_attention(
 }
 
 /**
- * @brief Non-flash (unfused) SDPA implementation using standard GEMM.
+ * @brief Non-flash (unfused) SDPA: full Q@K^T, softmax, then scores@V.
  *
- * Single full GEMM per head for Q@K^T and scores@V, with standard 3-pass
- * softmax (no tiling). Useful as a simpler baseline and for cases where
- * flash attention is not optimal (e.g. very short sequences).
+ * Avoids the within-head tiling overhead of flash attention, which hurts
+ * when Q has very few rows (e.g. decode with SeqLen=1). Float-only; no
+ * quantized input support.
  *
- * @tparam scalar_t Data type for computation
- * @param seq_dim Which dimension is sequence dimension (SeqDim::ONE or TWO)
- *                Used for all of Q, K, V, and output stride extraction.
- * @param start_pos Starting position for causal masking
- * @param num_keys_for_causal_attention Number of keys for causal attention
+ * @tparam scalar_t The data type for computation (float or double)
+ * @param output Output tensor [B, H, S_q, D] (or transposed SeqDim layout)
+ * @param query Query tensor [B, H, S_q, D]
+ * @param key Key tensor [B, H_kv, S_kv, D]
+ * @param value Value tensor [B, H_kv, S_kv, D]
+ * @param is_causal Whether to apply causal (lower-triangular) masking
+ * @param attn_mask Optional 2-D float attention mask [S_q, S_kv]
+ * @param scale Optional scaling factor (default 1/sqrt(D))
+ * @param q_seq_dim / k_seq_dim / v_seq_dim Sequence dim of Q/K/V (SeqDim::ONE or TWO)
+ * @param start_pos Starting position for causal masking during generation
+ * @param num_keys_for_causal_attention Number of keys to attend to (-1 = all)
  */
 template <typename scalar_t>
 void cpu_sdpa(
@@ -1090,7 +1096,9 @@ void cpu_sdpa(
     bool is_causal,
     const optional<Tensor>& attn_mask,
     const optional<double>& scale,
-    const SeqDim seq_dim,
+    const SeqDim q_seq_dim,
+    const SeqDim k_seq_dim,
+    const SeqDim v_seq_dim,
     const int64_t start_pos,
     const int64_t num_keys_for_causal_attention,
     const optional<Tensor>& q_zero_points = nullopt,
@@ -1099,23 +1107,25 @@ void cpu_sdpa(
     const optional<Tensor>& k_scales = nullopt,
     const optional<Tensor>& v_zero_points = nullopt,
     const optional<Tensor>& v_scales = nullopt) {
+  ET_CHECK_MSG(
+      query.scalar_type() != ScalarType::Char,
+      "Non-flash SDPA does not support quantized (int8) inputs");
+
   using accum_t = scalar_t;
   using Vec = vec::Vectorized<accum_t>;
   accum_t scaling_factor =
       static_cast<accum_t>(calculate_scale(query, scale));
+  // Dimension indices: SeqDim::TWO => [B,H,S,D], SeqDim::ONE => [B,S,H,D]
+  int64_t q_head_idx = 3 - static_cast<int64_t>(q_seq_dim);
+  int64_t k_head_idx = 3 - static_cast<int64_t>(k_seq_dim);
+  int64_t v_head_idx = 3 - static_cast<int64_t>(v_seq_dim);
+
   int64_t batchSize = query.size(0);
-  int64_t num_head = query.size(1);
-  int64_t qSize = query.size(2);
+  int64_t num_head = query.size(q_head_idx);
+  int64_t qSize = query.size(static_cast<int64_t>(q_seq_dim));
   int64_t headSize = query.size(3);
-  int64_t kvSize = value.size(2);
-  int64_t num_heads_kv = key.size(1);
-
-  if (seq_dim == SeqDim::ONE) {
-    num_head = query.size(2);
-    num_heads_kv = key.size(2);
-    qSize = query.size(1);
-    kvSize = value.size(1);
-  }
+  int64_t kvSize = key.size(static_cast<int64_t>(k_seq_dim));
+  int64_t num_heads_kv = key.size(k_head_idx);
 
   if (num_keys_for_causal_attention > 0) {
     ET_CHECK_MSG(
@@ -1126,7 +1136,7 @@ void cpu_sdpa(
 
   ET_CHECK_MSG(
       num_heads_kv <= num_head,
-      "cpu_sdpa does not support num kv heads > num query heads");
+      "cpu_sdpa: num kv heads > num query heads not supported");
   ET_CHECK_MSG(
       num_head % num_heads_kv == 0,
       "cpu_sdpa: num query heads must be divisible by num kv heads");
@@ -1134,32 +1144,35 @@ void cpu_sdpa(
   bool has_attn_mask =
       attn_mask.has_value() && attn_mask.value().numel();
   bool is_quantized_sdpa = query.scalar_type() == ScalarType::Char;
+  if (has_attn_mask) {
+    ET_CHECK_MSG(attn_mask.value().dim() == 2, "attn_mask must be 2D");
+  }
 
-  // Extract strides, swapping seq/head dims based on seq_dim
-  auto q_strides = query.strides();
-  int64_t qStrideB = q_strides[0];
-  int64_t qStrideH = (seq_dim == SeqDim::ONE) ? q_strides[2] : q_strides[1];
-  int64_t qStrideM = (seq_dim == SeqDim::ONE) ? q_strides[1] : q_strides[2];
+  // Extract strides (same pattern as cpu_flash_attention)
+  auto strides = query.strides();
+  int64_t qStrideB = strides[0];
+  int64_t qStrideH = strides[q_head_idx];
+  int64_t qStrideM = strides[static_cast<int64_t>(q_seq_dim)];
 
-  auto k_strides = key.strides();
-  int64_t kStrideB = k_strides[0];
-  int64_t kStrideH = (seq_dim == SeqDim::ONE) ? k_strides[2] : k_strides[1];
-  int64_t kStrideN = (seq_dim == SeqDim::ONE) ? k_strides[1] : k_strides[2];
+  strides = key.strides();
+  int64_t kStrideB = strides[0];
+  int64_t kStrideH = strides[k_head_idx];
+  int64_t kStrideN = strides[static_cast<int64_t>(k_seq_dim)];
 
-  auto v_strides = value.strides();
-  int64_t vStrideB = v_strides[0];
-  int64_t vStrideH = (seq_dim == SeqDim::ONE) ? v_strides[2] : v_strides[1];
-  int64_t vStrideN = (seq_dim == SeqDim::ONE) ? v_strides[1] : v_strides[2];
+  strides = value.strides();
+  int64_t vStrideB = strides[0];
+  int64_t vStrideH = strides[v_head_idx];
+  int64_t vStrideN = strides[static_cast<int64_t>(v_seq_dim)];
 
-  auto o_strides = output.strides();
-  int64_t oStrideB = o_strides[0];
-  int64_t oStrideH = (seq_dim == SeqDim::ONE) ? o_strides[2] : o_strides[1];
-  int64_t oStrideM = (seq_dim == SeqDim::ONE) ? o_strides[1] : o_strides[2];
+  strides = output.strides();
+  int64_t oStrideB = strides[0];
+  int64_t oStrideH = strides[q_head_idx];
+  int64_t oStrideM = strides[static_cast<int64_t>(q_seq_dim)];
 
   int64_t mStrideM = 0;
   if (has_attn_mask) {
-    auto m_strides = attn_mask.value().strides();
-    mStrideM = m_strides[0];
+    strides = attn_mask.value().strides();
+    mStrideM = strides[0];
   }
 
   int64_t q_quant_params_StrideB = 0;
@@ -1189,7 +1202,7 @@ void cpu_sdpa(
       v_quant_params_StrideN =
-          (seq_dim == SeqDim::ONE) ? v_qp_strides[1] : v_qp_strides[2];
+          (v_seq_dim == SeqDim::ONE) ? v_qp_strides[1] : v_qp_strides[2];
   }
-  // Allocate per-thread scores buffer: [qSize, kvSize] per (batch, head)
+  // Thread count for per-thread scratch allocation
 #ifdef ET_USE_THREADPOOL
   int64_t num_thread =
       ::executorch::extension::threadpool::get_threadpool()->get_thread_count();
@@ -1197,18 +1210,18 @@ void cpu_sdpa(
   int64_t num_thread = 1;
 #endif
 
-  int64_t scores_per_thread = qSize * kvSize;
-  int64_t size_bytes = scores_per_thread * num_thread * sizeof(accum_t);
+  // Allocate scores buffer: one [qSize x kvSize] matrix per thread
+  int64_t size_per_thread = qSize * kvSize;
+  int64_t size_bytes = size_per_thread * num_thread * sizeof(accum_t);
   std::unique_ptr<char[]> allocated_buf;
-  void* buf;
+  accum_t* scores_buf;
   Result<void*> scratch = ctx.allocate_temp(size_bytes, 64);
   if (!scratch.ok()) {
     allocated_buf = std::make_unique<char[]>(size_bytes);
-    buf = allocated_buf.get();
+    scores_buf = reinterpret_cast<accum_t*>(allocated_buf.get());
   } else {
-    buf = scratch.get();
+    scores_buf = reinterpret_cast<accum_t*>(scratch.get());
  }
-  accum_t* buf_data = reinterpret_cast<accum_t*>(buf);
 
   // Allocate dequantization buffer for V (used by _qk_at_v_gemm when m > 4)
   int64_t size_per_thread_qdq_vec = kvSize * headSize;
@@ -1228,6 +1241,7 @@ void cpu_sdpa(
     }
   }
 
+  // Data pointers
   const scalar_t* q_data = query.const_data_ptr<scalar_t>();
   const scalar_t* k_data = key.const_data_ptr<scalar_t>();
   const scalar_t* v_data = value.const_data_ptr<scalar_t>();
@@ -1235,17 +1249,18 @@ void cpu_sdpa(
       has_attn_mask ? attn_mask.value().const_data_ptr<accum_t>() : nullptr;
   scalar_t* out_data = output.mutable_data_ptr<scalar_t>();
 
+  // One work-unit per (batch, head); simpler than flash's (batch, head, block)
   auto compute_lambda = [&](int64_t begin, int64_t end) {
     int64_t ompIdx = torch::executor::get_thread_num();
-    accum_t* scores = buf_data + ompIdx * scores_per_thread;
+    accum_t* scores = scores_buf + ompIdx * size_per_thread;
     accum_t* buf_qdq_ptr = is_quantized_sdpa
         ? scratch_for_quant_dequant + ompIdx * size_per_thread_qdq_vec
         : nullptr;
 
-    for (int64_t idx = begin; idx < end; ++idx) {
-      int64_t b = idx / num_head;
-      int64_t h = idx % num_head;
-      int64_t kv_h = h / num_reps;
+    for (int64_t z = begin; z < end; z++) {
+      int64_t i = z / num_head; // batch index
+      int64_t j = z % num_head; // head index
+      int64_t j_kv = j / num_reps; // GQA: map query head to kv head
 
       const void* q_ptr;
       const void* k_ptr;
@@ -1257,9 +1272,9 @@ void cpu_sdpa(
      const float* q_scales_ptr = nullptr;
      const float* k_scales_ptr = nullptr;
      const float* v_scales_ptr = nullptr;
      const int8_t* q_zp_ptr = nullptr;
      const int8_t* k_zp_ptr = nullptr;
      const int8_t* v_zp_ptr = nullptr;
 
-      int64_t q_offset = b * qStrideB + h * qStrideH;
-      int64_t k_offset = b * kStrideB + kv_h * kStrideH;
-      int64_t v_offset = b * vStrideB + kv_h * vStrideH;
+      int64_t q_offset = i * qStrideB + j * qStrideH;
+      int64_t k_offset = i * kStrideB + j_kv * kStrideH;
+      int64_t v_offset = i * vStrideB + j_kv * vStrideH;
 
       if (is_quantized_sdpa) {
         q_ptr = reinterpret_cast<const int8_t*>(q_data) + q_offset;
@@ -1267,11 +1282,11 @@ void cpu_sdpa(
         k_ptr = reinterpret_cast<const int8_t*>(k_data) + k_offset;
         v_ptr = reinterpret_cast<const int8_t*>(v_data) + v_offset;
 
         int64_t q_qp_offset =
-            b * q_quant_params_StrideB + h * q_quant_params_StrideH;
+            i * q_quant_params_StrideB + j * q_quant_params_StrideH;
         int64_t k_qp_offset =
-            b * k_quant_params_StrideB + kv_h * k_quant_params_StrideH;
+            i * k_quant_params_StrideB + j_kv * k_quant_params_StrideH;
         int64_t v_qp_offset =
-            b * v_quant_params_StrideB + kv_h * v_quant_params_StrideH;
+            i * v_quant_params_StrideB + j_kv * v_quant_params_StrideH;
 
         q_scales_ptr =
             q_scales.value().const_data_ptr<float>() + q_qp_offset;
@@ -1290,7 +1305,7 @@ void cpu_sdpa(
         k_ptr = k_data + k_offset;
         v_ptr = v_data + v_offset;
       }
-      scalar_t* o_ptr = out_data + b * oStrideB + h * oStrideH;
+      scalar_t* out_ptr = out_data + i * oStrideB + j * oStrideH;
 
       // GEMM 1: scores[qSize, kvSize] = Q[qSize, D] @ K^T[D, kvSize]
       MaybeQuantizedMatrixData q_matrix(
@@ -1354,10 +1369,11 @@ void cpu_sdpa(
           qSize, headSize, kvSize,
           scores, kvSize,
          v_matrix, vStrideN,
-          o_ptr, oStrideM,
+          out_ptr, oStrideM,
           static_cast<accum_t>(0), buf_qdq_ptr);
     }
   };
+
   torch::executor::parallel_for(0, batchSize * num_head, 1, compute_lambda);
 }
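
A note on the dimension-index arithmetic introduced above, since it replaces
the explicit if (seq_dim == SeqDim::ONE) re-reads: the kernel only sees rank-4
tensors laid out as [Batch, Heads, Seq, Dim] (SeqDim::TWO) or
[Batch, Seq, Heads, Dim] (SeqDim::ONE), so the head dimension is always
3 - seq_dim. A minimal standalone sketch of that mapping, assuming
SeqDim::ONE == 1 and SeqDim::TWO == 2 (which the patch's static_cast indexing
implies); this is illustrative, not code from the repo:

    #include <cassert>
    #include <cstdint>

    enum class SeqDim : int64_t { ONE = 1, TWO = 2 };

    // Index of the sequence dimension in a rank-4 tensor.
    constexpr int64_t seq_idx(SeqDim s) { return static_cast<int64_t>(s); }
    // Heads occupy whichever of dims 1 and 2 the sequence does not.
    constexpr int64_t head_idx(SeqDim s) { return 3 - static_cast<int64_t>(s); }

    int main() {
      // SeqDim::TWO -> [B, H, S, D]: sequence at dim 2, heads at dim 1.
      assert(seq_idx(SeqDim::TWO) == 2 && head_idx(SeqDim::TWO) == 1);
      // SeqDim::ONE -> [B, S, H, D]: sequence at dim 1, heads at dim 2.
      assert(seq_idx(SeqDim::ONE) == 1 && head_idx(SeqDim::ONE) == 2);
      return 0;
    }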
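
The compute_lambda changes also rename the loop variables, so it helps to
restate the work partition: each flat index z in [0, batchSize * num_head) is
one (batch, query-head) unit, and grouped-query attention maps num_reps
consecutive query heads onto a single kv head. An illustrative sketch follows;
num_reps is computed outside the shown hunks, presumably as
num_head / num_heads_kv, and none of this is the kernel code:

    #include <cstdint>
    #include <cstdio>

    int main() {
      int64_t batchSize = 2, num_head = 8, num_heads_kv = 2;
      int64_t num_reps = num_head / num_heads_kv; // 4 query heads per kv head
      for (int64_t z = 0; z < batchSize * num_head; z++) {
        int64_t i = z / num_head;    // batch index
        int64_t j = z % num_head;    // query head index
        int64_t j_kv = j / num_reps; // kv head shared by num_reps query heads
        std::printf("z=%2lld -> batch %lld, q-head %lld, kv-head %lld\n",
                    static_cast<long long>(z), static_cast<long long>(i),
                    static_cast<long long>(j), static_cast<long long>(j_kv));
      }
      return 0;
    }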
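
For reviewers who want the math the unfused path implements, here is a naive
single-head reference: full Q@K^T, a three-pass row softmax (max, exp and sum,
normalize), then scores@V. It assumes contiguous [S_q, D] and [S_kv, D]
buffers and omits causal handling, masking, and the quantized plumbing; it is
a correctness reference under those assumptions, not the GEMM-based kernel:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    void naive_sdpa_head(
        const float* q, const float* k, const float* v, float* out,
        int64_t S_q, int64_t S_kv, int64_t D, float scale) {
      std::vector<float> scores(S_q * S_kv);
      // GEMM 1: scores = scale * (Q @ K^T)
      for (int64_t m = 0; m < S_q; m++) {
        for (int64_t n = 0; n < S_kv; n++) {
          float acc = 0.0f;
          for (int64_t d = 0; d < D; d++) {
            acc += q[m * D + d] * k[n * D + d];
          }
          scores[m * S_kv + n] = acc * scale;
        }
      }
      // Three-pass softmax over each row of scores
      for (int64_t m = 0; m < S_q; m++) {
        float* row = scores.data() + m * S_kv;
        float mx = row[0];
        for (int64_t n = 1; n < S_kv; n++) mx = std::max(mx, row[n]);
        float sum = 0.0f;
        for (int64_t n = 0; n < S_kv; n++) {
          row[n] = std::exp(row[n] - mx);
          sum += row[n];
        }
        for (int64_t n = 0; n < S_kv; n++) row[n] /= sum;
      }
      // GEMM 2: out = scores @ V
      for (int64_t m = 0; m < S_q; m++) {
        for (int64_t d = 0; d < D; d++) {
          float acc = 0.0f;
          for (int64_t n = 0; n < S_kv; n++) {
            acc += scores[m * S_kv + n] * v[n * D + d];
          }
          out[m * D + d] = acc;
        }
      }
    }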
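
Finally, the scores_buf rename keeps the existing allocation strategy: ask the
kernel context's temp allocator first, and fall back to a heap buffer owned by
a unique_ptr when that fails. A stripped-down sketch of the pattern; try_alloc
is a hypothetical stand-in for ctx.allocate_temp, which in the real API
returns a Result<void*> rather than a raw pointer:

    #include <cstddef>
    #include <memory>

    // Hypothetical allocator hook; pretend the context has no temp memory.
    void* try_alloc(std::size_t /*bytes*/) { return nullptr; }

    float* get_scores_buf(std::size_t size_bytes, std::unique_ptr<char[]>& owned) {
      if (void* p = try_alloc(size_bytes)) {
        return reinterpret_cast<float*>(p); // context-owned scratch
      }
      owned = std::make_unique<char[]>(size_bytes); // heap fallback
      return reinterpret_cast<float*>(owned.get());
    }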