diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index ed638d86ff6..df6abb79166 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -1349,42 +1349,47 @@ void cpu_sdpa(
         kStrideN,
         scores);
 
-    // Causal mask + scaling + attention mask + softmax per query row
+    // Scaling + causal-limited softmax per query row
     for (int64_t qi = 0; qi < qSize; ++qi) {
       accum_t* row = scores + qi * kvSize;
 
-      // Apply causal mask
-      if (is_causal) {
-        int64_t valid = std::min(start_pos + qi + 1, kvSize);
-        for (int64_t j = valid; j < kvSize; ++j) {
-          row[j] = -std::numeric_limits<accum_t>::infinity();
-        }
-      }
+      int64_t num_valid =
+          is_causal ? std::min(start_pos + qi + 1, kvSize) : kvSize;
 
       accum_t max_val;
-      const int kvSizeInt = static_cast<int>(kvSize);
+      const int num_valid_int = static_cast<int>(num_valid);
       if (has_attn_mask) {
-        // Apply scaling factor and attention mask in fusion
+        // Apply scaling factor and attention mask over valid range
        const accum_t* mask_row = mask_data + qi * mStrideM;
-        for (int64_t j = 0; j < kvSize; ++j) {
+        for (int64_t j = 0; j < num_valid; ++j) {
          row[j] = row[j] * scaling_factor + mask_row[j];
        }
         max_val = vec::reduce_all(
-            [](Vec& x, Vec& y) { return vec::maximum(x, y); }, row, kvSize);
+            [](Vec& x, Vec& y) { return vec::maximum(x, y); },
+            row,
+            num_valid);
       } else {
-        // Apply scaling factor and find max in fusion
+        // Apply scaling factor and find max over valid range
         _mul_reduce_max_fusion_kernel(
-            row, scaling_factor, kvSizeInt, row, max_val);
+            row, scaling_factor, num_valid_int, row, max_val);
       }
 
       if (max_val == -std::numeric_limits<accum_t>::infinity()) {
         fill_stub(row, static_cast<accum_t>(0), kvSize);
       } else {
         accum_t sum_val = max_val;
-        _exp_reduce_sum_fusion_kernel(row, kvSizeInt, row, sum_val);
+        _exp_reduce_sum_fusion_kernel(row, num_valid_int, row, sum_val);
         accum_t inv_sum = static_cast<accum_t>(1) / sum_val;
         vec::map(
-            [inv_sum](Vec x) { return x * Vec(inv_sum); }, row, row, kvSize);
+            [inv_sum](Vec x) { return x * Vec(inv_sum); },
+            row,
+            row,
+            num_valid);
+        // Zero out masked positions for GEMM 2
+        if (num_valid < kvSize) {
+          fill_stub(
+              row + num_valid, static_cast<accum_t>(0), kvSize - num_valid);
+        }
       }
     }
 