Skip to content

Commit 77d256e

Browse files
committed
Use unfused SDPA for short sequences (q_len <= 128 or kv_len <= 128)
Pull Request resolved: #18651 ATT ghstack-source-id: 378132080 @exported-using-ghexport Differential Revision: [D96044308](https://our.internmc.facebook.com/intern/diff/D96044308/)
1 parent da42a4d commit 77d256e

3 files changed

Lines changed: 12 additions & 8 deletions

File tree

.ci/scripts/test_lora.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,12 @@ else
151151
fi
152152

153153
### QUANTIZATION & PROGRAM DATA SEPARATION ###
154-
EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant:
155-
<think>
154+
EXPECTED_QUANT_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant: me
156155
Okay, so I need to calculate 15% of 80."
157156
EXPECTED_QUANT_LORA_PREFIX="
158157
<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
159-
To calculate 15% of 80, we can multiply 80 by 15/100 and then simplify the fraction.
160-
So, 15% of 80 is equal to (80 * 15) / 100 = 1200 / 100 = 12.
158+
To calculate 15% of 80, we can multiply 80 by 15/100.
159+
So, 15% of 80 is equal to 80 * 15/100 = 12.
161160
#### 12
162161
The answer is: 12<|im_end|>"
163162
EXPECTED_QUANT_LORA_ALTERNATE_PREFIX="
@@ -169,6 +168,7 @@ So, 15% of 80 is 12.
169168
The answer is: 12<|im_end|>"
170169

171170

171+
172172
# Export Quantized PTE, PTD file, no LoRA.
173173
# override base.lora_config=null to avoid creating a lora model
174174
# and loading lora weights.
@@ -228,7 +228,7 @@ fi
228228
NOW=$(date +"%H:%M:%S")
229229
echo "Test 4: Quantized, program-data separation lora. Starting to run llama runner at ${NOW}"
230230
# shellcheck source=/dev/null
231-
cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} --seq_len=104 > result.txt
231+
cmake-out/examples/models/llama/llama_main --model_path=qwen_lora_math_q.pte --data_paths="qwen_foundation_q.ptd,qwen_lora_math_q.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
232232
NOW=$(date +"%H:%M:%S")
233233
echo "Finished at ${NOW}"
234234

.ci/scripts/test_lora_multimethod.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,7 @@ EXPECTED_LORA_PREFIX="
8585
<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
8686
To calculate 15% of 80"
8787

88-
EXPECTED_BASE_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant:
89-
<think>
88+
EXPECTED_BASE_PREFIX="<|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant: me
9089
Okay, so I need to calculate 15% of 80."
9190

9291
### TEST 1: Run lora_forward method ###

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,12 @@ Tensor& custom_sdpa_out_impl(
412412
InvalidArgument,
413413
output);
414414

415-
bool use_unfused_sdpa = seq_len == 1;
415+
// Quantized GEMM kernels may not handle non-contiguous per-head strides
416+
// correctly when seq_dim=ONE and seq_len > 1, so keep the conservative
417+
// condition for quantized inputs.
418+
bool is_quantized = q.scalar_type() == ScalarType::Char;
419+
bool use_unfused_sdpa = (!is_quantized) &&
420+
(seq_len <= 128 || num_keys_for_causal_attention <= 128);
416421
if (use_unfused_sdpa) {
417422
ET_SWITCH_FLOAT_TYPES(output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
418423
sdpa::impl::cpu_sdpa<CTYPE>(

0 commit comments

Comments
 (0)