Commit 79a8168

Use unfused SDPA for short sequences (q_len <= 128 or kv_len <= 128)
ATT

Differential Revision: [D96044308](https://our.internmc.facebook.com/intern/diff/D96044308/)
ghstack-source-id: 361224789
Pull Request resolved: #18651
1 parent c234cd2 commit 79a8168

File tree

1 file changed: +7 −1 lines


extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 7 additions & 1 deletion
```diff
@@ -412,7 +412,13 @@ Tensor& custom_sdpa_out_impl(
       InvalidArgument,
       output);
 
-  bool use_unfused_sdpa = seq_len == 1;
+  // Quantized GEMM kernels may not handle non-contiguous per-head strides
+  // correctly when seq_dim=ONE and seq_len > 1, so keep the conservative
+  // condition for quantized inputs.
+  bool is_quantized = q.scalar_type() == ScalarType::Char;
+  bool use_unfused_sdpa = is_quantized
+      ? (seq_len == 1)
+      : (seq_len <= 128 || num_keys_for_causal_attention <= 128);
   if (use_unfused_sdpa) {
     ET_SWITCH_FLOAT_TYPES(
         output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
```
