Commit 93260ca

Use unfused SDPA for short sequences (q_len <= 128 or kv_len <= 128)
Pull Request resolved: #18651
ATT
ghstack-source-id: 373439495
@exported-using-ghexport
Differential Revision: [D96044308](https://our.internmc.facebook.com/intern/diff/D96044308/)
1 parent c73eb2d commit 93260ca

1 file changed

Lines changed: 5 additions & 1 deletion

extension/llm/custom_ops/op_sdpa.cpp

@@ -478,7 +478,11 @@ Tensor& custom_sdpa_out_impl(
       InvalidArgument,
       output);
 
-  bool use_unfused_sdpa = seq_len == 1;
+  // Quantized GEMM kernels may not handle non-contiguous per-head strides
+  // correctly when seq_dim=ONE and seq_len > 1, so keep the conservative
+  // condition for quantized inputs.
+  bool is_quantized = q.scalar_type() == ScalarType::Char;
+  bool use_unfused_sdpa = (!is_quantized) && (seq_len <= 128 || num_keys_for_causal_attention <= 128);
   if (use_unfused_sdpa) {
     ET_SWITCH_FLOAT_TYPES(output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
       sdpa::impl::cpu_sdpa<CTYPE>(
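
The dispatch condition added in this hunk can be read in isolation as a small predicate. The following is a minimal standalone sketch, not the ExecuTorch code itself: `DType`, `kShortSeqThreshold`, and `should_use_unfused_sdpa` are hypothetical names introduced here for illustration, and `num_keys` stands in for `num_keys_for_causal_attention`. Note that, as the condition is written in the diff, a quantized (`Char`) query never selects the unfused path, including when `seq_len == 1`.

```cpp
#include <cstdint>

// Hypothetical stand-in for the scalar types that matter to the check.
// The real code compares q.scalar_type() against ScalarType::Char.
enum class DType { Float, Half, Char };

// Threshold taken from the commit title (q_len <= 128 or kv_len <= 128).
constexpr std::int64_t kShortSeqThreshold = 128;

// Mirrors the boolean expression in the diff: the unfused path is chosen
// only for non-quantized inputs where either the query length or the
// number of keys is at most the threshold.
constexpr bool should_use_unfused_sdpa(
    DType q_dtype, std::int64_t seq_len, std::int64_t num_keys) {
  return q_dtype != DType::Char &&
      (seq_len <= kShortSeqThreshold || num_keys <= kShortSeqThreshold);
}

// Compile-time checks of a few representative shapes.
static_assert(should_use_unfused_sdpa(DType::Float, 1, 4096),
              "single-token decode takes the unfused path");
static_assert(should_use_unfused_sdpa(DType::Float, 512, 128),
              "a short key length alone is sufficient");
static_assert(!should_use_unfused_sdpa(DType::Float, 512, 4096),
              "long query and long key stay on the other path");
static_assert(!should_use_unfused_sdpa(DType::Char, 1, 4096),
              "quantized inputs are excluded by this expression");
```

Because the two length checks are joined with `||`, a long prefill against a short cache and a single-token decode against a long cache both select the unfused path; only the case where both lengths exceed 128 does not.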
