Skip to content

Commit e4ba4cf

Browse files
committed
Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_len <= 128)"
ATT Differential Revision: [D96044308](https://our.internmc.facebook.com/intern/diff/D96044308/) [ghstack-poisoned]
2 parents 98ebd08 + 5c4c8e6 commit e4ba4cf

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ Tensor& custom_sdpa_out_impl(
416416
// correctly when seq_dim=ONE and seq_len > 1, so keep the conservative
417417
// condition for quantized inputs.
418418
bool is_quantized = q.scalar_type() == ScalarType::Char;
419-
bool use_unfused_sdpa = is_quantized && (seq_len <= 128 || num_keys_for_causal_attention <= 128);
419+
bool use_unfused_sdpa = (!is_quantized) && (seq_len <= 128 || num_keys_for_causal_attention <= 128);
420420
if (use_unfused_sdpa) {
421421
ET_SWITCH_FLOAT_TYPES(output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
422422
sdpa::impl::cpu_sdpa<CTYPE>(

0 commit comments

Comments
 (0)