Commit ee90ded

Update on "Use unfused SDPA for short sequences (q_len <= 128 or kv_len <= 128)"
ATT Differential Revision: [D96044308](https://our.internmc.facebook.com/intern/diff/D96044308/) [ghstack-poisoned]
2 parents cc3c3fe + fbc96c9 commit ee90ded

3 files changed

Lines changed: 5 additions & 4 deletions

.ci/scripts/test_lora.sh

Lines changed: 2 additions & 2 deletions
@@ -133,14 +133,14 @@ else
 fi

 ### QUANTIZATION & PROGRAM DATA SEPARATION ###
-EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant:
+EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant: me
 <think>
 <think>
 Okay, so I need to calculate 15% of 80."
 EXPECTED_QUANT_LORA_PREFIX="
 <|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
 To calculate 15% of 80, we can multiply 80 by 15/100.
-80 * 0.15 = 12.
+80 * 15/100 = 12.
 So, 15% of 80 is 12.
 #### 12
 The answer is: 12<|im_end|>"

.ci/scripts/test_lora_multimethod.sh

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ EXPECTED_LORA_PREFIX="
 <|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
 To calculate 15% of 80"

-EXPECTED_BASE_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant:
+EXPECTED_BASE_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant: me
 <think>
 <think>
 Okay, so I need to calculate 15% of 80."

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 2 additions & 1 deletion
@@ -416,7 +416,8 @@ Tensor& custom_sdpa_out_impl(
   // correctly when seq_dim=ONE and seq_len > 1, so keep the conservative
   // condition for quantized inputs.
   bool is_quantized = q.scalar_type() == ScalarType::Char;
-  bool use_unfused_sdpa = (!is_quantized) && (seq_len <= 128 || num_keys_for_causal_attention <= 128);
+  bool use_unfused_sdpa = (!is_quantized) &&
+      (seq_len <= 128 || num_keys_for_causal_attention <= 128);
   if (use_unfused_sdpa) {
     ET_SWITCH_FLOAT_TYPES(output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
       sdpa::impl::cpu_sdpa<CTYPE>(
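For readers who want to reason about the dispatch condition in isolation, the predicate in this hunk can be sketched as a standalone function. This is an illustrative sketch only, not code from the commit: the function name `should_use_unfused_sdpa` and the constant `kUnfusedSdpaMaxLen` are hypothetical, and the only things taken from the diff are the `128` threshold, the `!is_quantized` guard, and the two length operands.

```cpp
#include <cstdint>

// Hypothetical name for the threshold that the diff hard-codes as 128.
constexpr int64_t kUnfusedSdpaMaxLen = 128;

// Hypothetical helper mirroring the predicate in the diff:
// quantized inputs never take the unfused path; non-quantized inputs take it
// when either the query length or the number of keys is at most 128.
constexpr bool should_use_unfused_sdpa(
    bool is_quantized,
    int64_t seq_len,
    int64_t num_keys_for_causal_attention) {
  return !is_quantized &&
      (seq_len <= kUnfusedSdpaMaxLen ||
       num_keys_for_causal_attention <= kUnfusedSdpaMaxLen);
}

// Short query against a long KV cache (typical decode step): unfused.
static_assert(should_use_unfused_sdpa(false, 1, 4096), "short query");
// Long query and long key range: stays on the other path.
static_assert(!should_use_unfused_sdpa(false, 512, 512), "both long");
// Quantized input: never unfused, regardless of lengths.
static_assert(!should_use_unfused_sdpa(true, 1, 64), "quantized");
```

Because the two length checks are joined with `||`, a single short dimension is enough to select the unfused path, which matches the commit title's "q_len <= 128 or kv_len <= 128".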
