Skip to content

Commit c234cd2

Browse files
committed
Limit softmax to causally-valid elements in cpu_sdpa
Instead of setting masked positions to -inf and computing max/exp/normalize over all kvSize elements, limit the softmax to the causally valid range of each row. This avoids unnecessary computation on masked positions, which are instead zero-filled for GEMM 2.

Differential Revision: [D96044307](https://our.internmc.facebook.com/intern/diff/D96044307/)
ghstack-source-id: 361224791
Pull Request resolved: #18650
1 parent f541c79 commit c234cd2

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,44 +1336,44 @@ void cpu_sdpa(
13361336
k_matrix, kStrideN,
13371337
scores);
13381338

1339-
// Causal mask + scaling + attention mask + softmax per query row
1339+
// Scaling + causal-limited softmax per query row
13401340
for (int64_t qi = 0; qi < qSize; ++qi) {
13411341
accum_t* row = scores + qi * kvSize;
13421342

1343-
// Apply causal mask
1344-
if (is_causal) {
1345-
int64_t valid = std::min(start_pos + qi + 1, kvSize);
1346-
for (int64_t j = valid; j < kvSize; ++j) {
1347-
row[j] = -std::numeric_limits<accum_t>::infinity();
1348-
}
1349-
}
1343+
int64_t num_valid = is_causal
1344+
? std::min(start_pos + qi + 1, kvSize) : kvSize;
13501345

13511346
accum_t max_val;
1352-
const int kvSizeInt = static_cast<int>(kvSize);
1347+
const int num_valid_int = static_cast<int>(num_valid);
13531348
if (has_attn_mask) {
1354-
// Apply scaling factor and attention mask in fusion
1349+
// Apply scaling factor and attention mask over valid range
13551350
const accum_t* mask_row = mask_data + qi * mStrideM;
1356-
for (int64_t j = 0; j < kvSize; ++j) {
1351+
for (int64_t j = 0; j < num_valid; ++j) {
13571352
row[j] = row[j] * scaling_factor + mask_row[j];
13581353
}
13591354
max_val = vec::reduce_all<accum_t>(
13601355
[](Vec& x, Vec& y) { return vec::maximum(x, y); },
1361-
row, kvSize);
1356+
row, num_valid);
13621357
} else {
1363-
// Apply scaling factor and find max in fusion
1358+
// Apply scaling factor and find max over valid range
13641359
_mul_reduce_max_fusion_kernel(
1365-
row, scaling_factor, kvSizeInt, row, max_val);
1360+
row, scaling_factor, num_valid_int, row, max_val);
13661361
}
13671362

13681363
if (max_val == -std::numeric_limits<accum_t>::infinity()) {
13691364
fill_stub(row, static_cast<accum_t>(0), kvSize);
13701365
} else {
13711366
accum_t sum_val = max_val;
1372-
_exp_reduce_sum_fusion_kernel(row, kvSizeInt, row, sum_val);
1367+
_exp_reduce_sum_fusion_kernel(row, num_valid_int, row, sum_val);
13731368
accum_t inv_sum = static_cast<accum_t>(1) / sum_val;
13741369
vec::map<accum_t>(
13751370
[inv_sum](Vec x) { return x * Vec(inv_sum); },
1376-
row, row, kvSize);
1371+
row, row, num_valid);
1372+
// Zero out masked positions for GEMM 2
1373+
if (num_valid < kvSize) {
1374+
fill_stub(
1375+
row + num_valid, static_cast<accum_t>(0), kvSize - num_valid);
1376+
}
13771377
}
13781378
}
13791379

0 commit comments

Comments
 (0)