
Commit 8996ef1

yeyu-nvidia and claude committed
Fix skip-softmax threshold formula: remove erroneous * sm_scale factor
The BLASST criterion (https://arxiv.org/pdf/2512.12087) checks ln(lambda) on the sm_scale-SCALED attention logits a_ij = q·k/sqrt(d). The Triton kernel stores scores as x = a * log2(e), so the correct threshold in kernel (log2) space is log2(lambda), not log2(lambda) * sm_scale.

The previous code multiplied by sm_scale (~0.088 for head_dim=128), making every threshold ~11× too aggressive. With lambda=0.1 the kernel-space threshold was -0.29 instead of the correct -3.32, skipping most attention tiles and producing garbage output (PSNR ~11 dB). Even lambda=0.0001 was still too aggressive (-1.18 vs the correct -13.29).

Fix: use `log2(lambda)` directly as SKIP_THRESHOLD_LOG2, and restore the default threshold to 0.1 (the standard BLASST value).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 3cb983c commit 8996ef1

2 files changed

Lines changed: 18 additions & 7 deletions
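As a sanity check (not part of the diff), the thresholds quoted in the commit message can be reproduced in a few lines of Python; head_dim=128 and sm_scale = 1/sqrt(head_dim) are taken from the message:

```python
import math

# Reproduce the kernel-space thresholds quoted in the commit message.
head_dim = 128
sm_scale = 1.0 / math.sqrt(head_dim)  # ~0.0884 for head_dim=128

for lam in (0.1, 0.0001):
    correct = math.log2(lam)            # threshold in kernel (log2) space
    buggy = math.log2(lam) * sm_scale   # what the old code computed
    print(f"lambda={lam}: correct={correct:.2f}, buggy={buggy:.2f}")
# For lambda=0.1: correct=-3.32, buggy=-0.29 — the "11x too aggressive"
# factor is exactly 1/sm_scale = sqrt(128) ≈ 11.3.
```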


examples/diffusers/quantization/wan2_sage_attention.py

Lines changed: 1 addition & 1 deletion
@@ -450,7 +450,7 @@ def attention_kernel_ctx(kernel: str = KERNEL_FP8):
         }
     }

-_TRITON_SKIP_DEFAULT_THRESHOLD = 0.01
+_TRITON_SKIP_DEFAULT_THRESHOLD = 0.1

 _TRITON_SKIP_CONFIG = {
     "sparse_cfg": {

modelopt/torch/kernels/triton_fa.py

Lines changed: 17 additions & 6 deletions
@@ -996,14 +996,25 @@ def forward(
     # Triton tiles must be powers of 2; pad head dim
     BLOCK_D = triton.next_power_of_2(HEAD_DIM)

-    # Skip-softmax: convert threshold to scaled log2 space for the kernel.
-    # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks
-    # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space
-    # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we
-    # pre-scale: threshold_scaled = log2(lambda) * sm_scale.
+    # Skip-softmax: convert lambda threshold to log2 space for the kernel.
+    #
+    # BLASST (https://arxiv.org/pdf/2512.12087) checks the criterion on the
+    # sm_scale-SCALED attention logits a_ij = q·k / sqrt(d):
+    #
+    #     tile_max_a < running_max_a + ln(lambda)
+    #
+    # The Triton kernel stores scores as x = a * log2(e) (for exp2 efficiency),
+    # so a = x * ln(2). Substituting:
+    #
+    #     tile_max_x * ln(2) < running_max_x * ln(2) + ln(lambda)
+    #     tile_max_x < running_max_x + log2(lambda)
+    #
+    # Therefore the threshold in kernel (log2) space is simply log2(lambda).
+    # Do NOT multiply by sm_scale — that factor is already absorbed into the
+    # log2(e) conversion above.
     apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
     if apply_skip:
-        skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
+        skip_threshold_log2 = math.log2(skip_softmax_threshold)
     else:
         skip_threshold_log2 = 0.0
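The substitution in the new comment can also be checked numerically. A minimal sketch (hypothetical maxima, not kernel code): the BLASST criterion on the scaled logits in ln space and the same criterion on the kernel's log2(e)-scaled scores with log2(lambda) yield identical skip decisions:

```python
import math
import random

LOG2E = math.log2(math.e)
lam = 0.1  # the default BLASST threshold restored by this commit

random.seed(0)
for _ in range(1000):
    # Hypothetical scaled-logit maxima a (as in the BLASST criterion).
    tile_max_a = random.uniform(-30.0, 5.0)
    running_max_a = random.uniform(-30.0, 5.0)

    # Criterion on the scaled logits (ln space).
    skip_ln = tile_max_a < running_max_a + math.log(lam)

    # Same criterion on the kernel's scores x = a * log2(e) (log2 space).
    tile_max_x = tile_max_a * LOG2E
    running_max_x = running_max_a * LOG2E
    skip_log2 = tile_max_x < running_max_x + math.log2(lam)

    assert skip_ln == skip_log2  # decisions agree; no sm_scale factor needed
```

Multiplying the log2-space threshold by sm_scale, as the old code did, breaks this equivalence for any sm_scale != 1.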

0 commit comments
