Skip to content

Commit 249bcb0

Browse files
committed
Limit softmax to causally-valid elements in cpu_sdpa
Pull Request resolved: #18650

Instead of setting masked positions to -inf and computing max/exp/normalize over all kvSize elements, limit the softmax to only the causally-valid range of each query row. This avoids unnecessary computation on masked positions; those positions are then zero-filled so that they contribute nothing to GEMM 2.

ghstack-source-id: 374666321
@exported-using-ghexport

Differential Revision: [D96044307](https://our.internmc.facebook.com/intern/diff/D96044307/)
1 parent 3d4ca38 commit 249bcb0

1 file changed

Lines changed: 21 additions & 16 deletions

File tree

extension/llm/custom_ops/op_sdpa_impl.h

Lines changed: 21 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1349,42 +1349,47 @@ void cpu_sdpa(
13491349
kStrideN,
13501350
scores);
13511351

1352-
// Causal mask + scaling + attention mask + softmax per query row
1352+
// Scaling + causal-limited softmax per query row
13531353
for (int64_t qi = 0; qi < qSize; ++qi) {
13541354
accum_t* row = scores + qi * kvSize;
13551355

1356-
// Apply causal mask
1357-
if (is_causal) {
1358-
int64_t valid = std::min(start_pos + qi + 1, kvSize);
1359-
for (int64_t j = valid; j < kvSize; ++j) {
1360-
row[j] = -std::numeric_limits<accum_t>::infinity();
1361-
}
1362-
}
1356+
int64_t num_valid =
1357+
is_causal ? std::min(start_pos + qi + 1, kvSize) : kvSize;
13631358

13641359
accum_t max_val;
1365-
const int kvSizeInt = static_cast<int>(kvSize);
1360+
const int num_valid_int = static_cast<int>(num_valid);
13661361
if (has_attn_mask) {
1367-
// Apply scaling factor and attention mask in fusion
1362+
// Apply scaling factor and attention mask over valid range
13681363
const accum_t* mask_row = mask_data + qi * mStrideM;
1369-
for (int64_t j = 0; j < kvSize; ++j) {
1364+
for (int64_t j = 0; j < num_valid; ++j) {
13701365
row[j] = row[j] * scaling_factor + mask_row[j];
13711366
}
13721367
max_val = vec::reduce_all<accum_t>(
1373-
[](Vec& x, Vec& y) { return vec::maximum(x, y); }, row, kvSize);
1368+
[](Vec& x, Vec& y) { return vec::maximum(x, y); },
1369+
row,
1370+
num_valid);
13741371
} else {
1375-
// Apply scaling factor and find max in fusion
1372+
// Apply scaling factor and find max over valid range
13761373
_mul_reduce_max_fusion_kernel(
1377-
row, scaling_factor, kvSizeInt, row, max_val);
1374+
row, scaling_factor, num_valid_int, row, max_val);
13781375
}
13791376

13801377
if (max_val == -std::numeric_limits<accum_t>::infinity()) {
13811378
fill_stub(row, static_cast<accum_t>(0), kvSize);
13821379
} else {
13831380
accum_t sum_val = max_val;
1384-
_exp_reduce_sum_fusion_kernel(row, kvSizeInt, row, sum_val);
1381+
_exp_reduce_sum_fusion_kernel(row, num_valid_int, row, sum_val);
13851382
accum_t inv_sum = static_cast<accum_t>(1) / sum_val;
13861383
vec::map<accum_t>(
1387-
[inv_sum](Vec x) { return x * Vec(inv_sum); }, row, row, kvSize);
1384+
[inv_sum](Vec x) { return x * Vec(inv_sum); },
1385+
row,
1386+
row,
1387+
num_valid);
1388+
// Zero out masked positions for GEMM 2
1389+
if (num_valid < kvSize) {
1390+
fill_stub(
1391+
row + num_valid, static_cast<accum_t>(0), kvSize - num_valid);
1392+
}
13881393
}
13891394
}
13901395

0 commit comments

Comments
 (0)