
Commit 708f113

Revert skip-softmax threshold formula change: restore * sm_scale

The * sm_scale factor is intentional: it scales the tile-skip threshold relative to head dimension, so larger head_dim (smaller sm_scale) produces more aggressive sparsity for the same lambda value. The previous 'fix' was incorrect.

Signed-off-by: Ye Yu <yeyu@nvidia.com>

1 parent c548f6f · commit 708f113

File tree: 1 file changed, +4 −16 lines


modelopt/torch/kernels/triton_fa.py

Lines changed: 4 additions & 16 deletions
@@ -1003,29 +1003,17 @@ def forward(
         BLOCK_D = triton.next_power_of_2(HEAD_DIM)
 
         # Skip-softmax: convert lambda threshold to log2 space for the kernel.
-        #
-        # BLASST (https://arxiv.org/pdf/2512.12087) checks the criterion on the
-        # sm_scale-SCALED attention logits a_ij = q·k / sqrt(d):
-        #
-        #     tile_max_a < running_max_a + ln(lambda)
-        #
-        # The Triton kernel stores scores as x = a * log2(e) (for exp2 efficiency),
-        # so a = x * ln(2). Substituting:
-        #
-        #     tile_max_x * ln(2) < running_max_x * ln(2) + ln(lambda)
-        #     tile_max_x < running_max_x + log2(lambda)
-        #
-        # Therefore the threshold in kernel (log2) space is simply log2(lambda).
-        # Do NOT multiply by sm_scale — that factor is already absorbed into the
-        # log2(e) conversion above.
+        # The threshold is scaled by sm_scale to control sparsity relative to
+        # head dimension: larger head_dim → smaller sm_scale → more aggressive
+        # skipping for the same lambda value.
         if quantize_p and (q.requires_grad or k.requires_grad or v.requires_grad):
             raise NotImplementedError(
                 "quantize_p supports inference only; backward does not model the quantized P path"
             )
 
         apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
         if apply_skip:
-            skip_threshold_log2 = math.log2(skip_softmax_threshold)
+            skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
         else:
             skip_threshold_log2 = 0.0
 
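To see why the restored `* sm_scale` factor makes skipping more aggressive for larger head dimensions, consider a small numeric sketch. The helper below is illustrative, not part of `triton_fa.py`; it assumes the conventional attention scaling `sm_scale = 1/sqrt(head_dim)`, and that for `lambda < 1` the threshold `log2(lambda)` is negative.

```python
import math

def skip_threshold_log2(lam: float, head_dim: int) -> float:
    """Hypothetical helper mirroring the restored formula:
    log2(lambda) * sm_scale, assuming sm_scale = 1/sqrt(head_dim)."""
    sm_scale = 1.0 / math.sqrt(head_dim)
    return math.log2(lam) * sm_scale

# For lambda = 0.5, log2(lambda) = -1:
t64 = skip_threshold_log2(0.5, 64)    # -1 * 0.125  = -0.125
t128 = skip_threshold_log2(0.5, 128)  # -1 * 0.0884 ≈ -0.088

# The larger head_dim (smaller sm_scale) yields a threshold closer to
# zero, so a skip criterion of the form
#     tile_max_x < running_max_x + threshold
# is satisfied by more tiles: more aggressive sparsity for the same
# lambda, as the commit message states.
assert t128 > t64
```

Without the `* sm_scale` factor, the threshold would be `log2(lambda)` regardless of head dimension, and the same lambda would prune identically at every head_dim; the multiplication is what ties sparsity to the scale of the logits.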