Commit 2bad66c
[3/n] Add skip-softmax to Triton flash attention kernel (#1081)
### What does this PR do?

**Type of change:** New feature.

Add skip-softmax tile skipping to the Triton flash attention kernel.

### Usage

```python
import torch

from modelopt.torch.kernels import attention

# Skip-softmax with threshold 0.1 (tiles contributing < 10% are skipped)
out = attention(q, k, v, b_start_loc, b_seq_len, max_len, skip_softmax_threshold=0.1)

# Via mtsa.sparsify() on HuggingFace models
import modelopt.torch.sparsity.attention_sparsity as mtsa
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16, device_map="cuda"
)

# Default config
mtsa.sparsify(model, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT)
```

### Testing

Performance in TFLOPS (RTX 6000 Pro):

| SEQ_LEN | ModelOpt Triton | PyTorch SDPA | Flash Attention 2 | Skip-Softmax t=0.01 | Skip-Softmax t=0.1 |
|---:|---:|---:|---:|---:|---:|
| 16384 | 188.85 | 211.72 | 224.24 | 172.90 | 279.86 |
| 32768 | 175.32 | 212.82 | 224.83 | 146.15 | 262.49 |
| 65536 | 167.30 | 214.93 | 226.46 | 145.08 | 243.34 |

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅ / ❌ / N/A
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A
- Did you write any new necessary tests?: ✅ / ❌ / N/A
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A

## Summary by CodeRabbit

* **New Features**
  * Added a Triton "skip-softmax" tile-skipping option for flash attention with a new attention keyword and configurable threshold (default 0.1).
  * Added a new sparse attention method and a default sparse configuration that enables the Triton skip-softmax method.
* **Tests**
  * Added GPU tests covering threshold behavior, numerical fidelity vs dense, shape preservation, decode mode, and integration with sparsify.
* **Documentation**
  * Updated the changelog for the new feature and removed two previously listed entries.

---------

Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent b1f9f01 commit 2bad66c

File tree

10 files changed: +602 −110 lines

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -8,6 +8,7 @@ NVIDIA Model Optimizer Changelog
 - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use a custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in a slightly different pruned model because of different kernels and numerics.
 - Added iterator interface using CalibrationDataReader in ONNX quantization workflow.
 - Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
+- Add skip-softmax tile skipping to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
 - Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.

 **Bug Fixes**
```

modelopt/torch/kernels/hf_triton_attention.py

Lines changed: 14 additions & 10 deletions

```diff
@@ -105,16 +105,20 @@ def triton_attention_forward(
     kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
     kw["max_input_len_k"] = seq_k

-    # N:M sparse softmax — prefill only (decode should not sparsify KV)
-    if not is_decode and getattr(module, "_apply_sparse_nm", False):
-        # _sparse_method_instance is set by SparseAttentionModule._init_sparse_method()
-        # in modelopt/torch/sparsity/attention_sparsity/sparse_attention.py
-        method = getattr(module, "_sparse_method_instance", None)
-        if method is not None:
-            kw["sparsity_n"] = getattr(method, "sparsity_n", 2)
-            kw["sparsity_m"] = getattr(method, "sparsity_m", 4)
-            kw["num_sink_tokens"] = getattr(method, "num_sink_tokens", 0)
-            kw["dense_window_size"] = getattr(method, "dense_window_size", 64)
+    # Sparse attention params
+    method = getattr(module, "_sparse_method_instance", None)
+
+    # N:M sparse softmax: prefill only (no perf benefit for decode)
+    if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False):
+        kw["sparsity_n"] = method.sparsity_n
+        kw["sparsity_m"] = method.sparsity_m
+        kw["num_sink_tokens"] = method.num_sink_tokens
+        kw["dense_window_size"] = method.dense_window_size
+
+    # Skip-softmax: applies to both prefill and decode
+    if method is not None and getattr(module, "_apply_skip_softmax", False):
+        if method.skip_softmax_threshold:
+            kw["skip_softmax_threshold"] = method.skip_softmax_threshold

     o = attention(q, k, v, **kw)
```
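The kwarg-gating above can be exercised without the Triton kernel. A minimal sketch, with stand-in objects (`collect_sparse_kwargs`, the `SimpleNamespace` module/method, and the default values are illustrative, not ModelOpt code), showing which kwargs would reach `attention()`:

```python
from types import SimpleNamespace

def collect_sparse_kwargs(module, is_decode):
    """Mirror of the kwarg-gating logic in triton_attention_forward (sketch)."""
    kw = {}
    method = getattr(module, "_sparse_method_instance", None)

    # N:M sparse softmax: prefill only
    if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False):
        kw["sparsity_n"] = method.sparsity_n
        kw["sparsity_m"] = method.sparsity_m

    # Skip-softmax: both prefill and decode; a 0/None threshold disables it
    if method is not None and getattr(module, "_apply_skip_softmax", False):
        if method.skip_softmax_threshold:
            kw["skip_softmax_threshold"] = method.skip_softmax_threshold
    return kw

method = SimpleNamespace(sparsity_n=2, sparsity_m=4, skip_softmax_threshold=0.1)
module = SimpleNamespace(
    _sparse_method_instance=method, _apply_sparse_nm=False, _apply_skip_softmax=True
)

# Unlike N:M sparsity, skip-softmax is forwarded in decode mode too.
print(collect_sparse_kwargs(module, is_decode=True))   # {'skip_softmax_threshold': 0.1}
```

This matches the diff's structural change: `method` is resolved once, and the two sparsity features are gated independently.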
modelopt/torch/kernels/triton_fa.py

Lines changed: 134 additions & 22 deletions

```diff
@@ -23,6 +23,8 @@
     metadata (b_start_loc, b_seq_len). Supports causal masking and autograd.
 """

+import math
+
 import torch
 import triton
 import triton.language as tl
```

```diff
@@ -248,6 +250,8 @@ def _attn_fwd(
     SPARSITY_M: tl.constexpr = 4,  # N:M sparsity — group size (4 or 8)
     NUM_SINK_TOKENS: tl.constexpr = 0,  # KV positions before this are kept dense (attention sinks)
     DENSE_WINDOW_SIZE: tl.constexpr = 64,  # Tokens near diagonal kept dense (absolute, BLOCK_N-independent)
+    APPLY_SKIP_SOFTMAX: tl.constexpr = False,  # Skip KV tiles with negligible scores
+    SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0,  # log2(lambda) * sm_scale, pre-scaled for comparison on scaled scores
 ):
     # --- Grid: (batch, num_q_heads, num_q_tiles) ---
     # Example: batch=2, num_q_heads=32, seq_len=256, BLOCK_M=128
```

```diff
@@ -320,26 +324,65 @@ def _attn_fwd(
             scores, BLOCK_M, BLOCK_N, SPARSITY_N, SPARSITY_M
         )

-        # --- Online softmax update ---
-        # 1. Update running max
-        m_new = tl.maximum(row_max, tl.max(scores, 1))
-        # 2. Compute unnormalized attention weights
-        p = tl.math.exp2(scores - m_new[:, None])
-        l_new = tl.sum(p, 1)
-        # 3. Correction factor: rescale previous tiles when max changes
-        correction = tl.math.exp2(row_max - m_new)
-        row_sum = row_sum * correction + l_new
-        acc = acc * correction[:, None]
-
-        # Load V [BLOCK_N, BLOCK_D] and accumulate: acc += attn_weights @ V
-        v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
-        v = tl.load(
-            v_base + v_offs,
-            mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :],
-            other=0.0,
-        )
-        acc = tl.dot(p.to(v.dtype), v, acc)
-        row_max = m_new
+        if APPLY_SKIP_SOFTMAX:
+            # --- Skip-softmax (BLASST, https://arxiv.org/pdf/2512.12087) ---
+            #
+            # Algorithm: During FlashAttention's block-wise computation, we
+            # maintain a running maximum m_i^(j) across blocks. If a block's
+            # local maximum ~m_i^(j) is significantly smaller than the running
+            # maximum m_i^(j):
+            #
+            #     ~m_i^(j) - m_i^(j) < ln(lambda)
+            #
+            # then exp(~m_i^(j) - m_i^(j)) < lambda ≈ 0, meaning the block's
+            # contribution to the final output is negligible. We skip the
+            # softmax computation, V load, and BMM2 computation entirely.
+            #
+            # The threshold is pre-scaled by qk_scale in the Python wrapper so
+            # it can be compared directly against scaled scores (matching the
+            # BLASST reference semantics on unscaled scores).
+            tile_row_max = tl.max(scores, 1)  # [BLOCK_M] — ~m_i^(j) (scaled)
+            # Per-row: True if row's tile max is negligible vs running max
+            can_skip = tile_row_max < (row_max + SKIP_THRESHOLD_LOG2)
+            # Per-tile: skip entire tile only if ALL rows are negligible
+            skip_tile = tl.min(can_skip.to(tl.int32)) == 1
+
+            if not skip_tile:
+                m_new = tl.maximum(row_max, tile_row_max)
+                p = tl.math.exp2(scores - m_new[:, None])
+                l_new = tl.sum(p, 1)
+                correction = tl.math.exp2(row_max - m_new)
+                row_sum = row_sum * correction + l_new
+                acc = acc * correction[:, None]
+
+                v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
+                v = tl.load(
+                    v_base + v_offs,
+                    mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :],
+                    other=0.0,
+                )
+                acc = tl.dot(p.to(v.dtype), v, acc)
+                row_max = m_new
+            # else: tile skipped: no softmax computation, V load, or BMM2 computation
+        else:
+            # --- Standard path: no skip check ---
+            # Online softmax update
+            m_new = tl.maximum(row_max, tl.max(scores, 1))
+            p = tl.math.exp2(scores - m_new[:, None])
+            l_new = tl.sum(p, 1)
+            correction = tl.math.exp2(row_max - m_new)
+            row_sum = row_sum * correction + l_new
+            acc = acc * correction[:, None]
+
+            # Load V and accumulate
+            v_offs = (kv_offset + kv_start + kv_pos[:, None]) * stride_vbs + dim_pos[None, :]
+            v = tl.load(
+                v_base + v_offs,
+                mask=((kv_start + kv_pos[:, None]) < seq_len_kv) & d_mask[None, :],
+                other=0.0,
+            )
+            acc = tl.dot(p.to(v.dtype), v, acc)
+            row_max = m_new

     # --- Final normalization: output = acc / row_sum ---
     acc = acc / row_sum[:, None]
```

```diff
@@ -440,6 +483,8 @@ def _attn_bwd_dq(
     SPARSITY_M: tl.constexpr = 4,
     NUM_SINK_TOKENS: tl.constexpr = 0,
     DENSE_WINDOW_SIZE: tl.constexpr = 64,
+    APPLY_SKIP_SOFTMAX: tl.constexpr = False,
+    SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0,
 ):
     """Phase 3 of backward: compute dQ for one Q tile, looping over KV tiles.
```

```diff
@@ -523,6 +568,16 @@ def _attn_bwd_dq(

         p = tl.math.exp2(scores - lse[:, None])

+        # Skip-softmax backward: zero out P for rows with negligible contribution.
+        # Per-row using final LSE because forward/backward tile sizes may differ
+        # (forward autotunes BLOCK_N; backward uses a fixed size), so per-tile
+        # skip masks from forward wouldn't align. LSE >= any intermediate running
+        # max, so this conservatively zeros out at least what forward skipped.
+        if APPLY_SKIP_SOFTMAX:
+            tile_row_max = tl.max(scores, 1)
+            can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2)
+            p = tl.where(can_skip[:, None], 0.0, p)
+
         # dP = dO @ V^T, dS = P * (dP - delta), dQ += dS @ K
         dp = tl.dot(do, tl.trans(v))
         ds = p * (dp - row_delta[:, None])
```

```diff
@@ -574,6 +629,8 @@ def _attn_bwd_dkdv(
     SPARSITY_M: tl.constexpr = 4,
     NUM_SINK_TOKENS: tl.constexpr = 0,
     DENSE_WINDOW_SIZE: tl.constexpr = 64,
+    APPLY_SKIP_SOFTMAX: tl.constexpr = False,
+    SKIP_THRESHOLD_LOG2: tl.constexpr = 0.0,
 ):
     """Phase 2 of backward: compute dK, dV for one KV tile.
```

```diff
@@ -665,6 +722,16 @@ def _attn_bwd_dkdv(

         p = tl.math.exp2(scores - lse[:, None])

+        # Skip-softmax backward: zero out P for rows with negligible contribution.
+        # Per-row using final LSE because forward/backward tile sizes may differ
+        # (forward autotunes BLOCK_N; backward uses a fixed size), so per-tile
+        # skip masks from forward wouldn't align. LSE >= any intermediate running
+        # max, so this conservatively zeros out at least what forward skipped.
+        if APPLY_SKIP_SOFTMAX:
+            tile_row_max = tl.max(scores, 1)
+            can_skip = tile_row_max < (lse + SKIP_THRESHOLD_LOG2)
+            p = tl.where(can_skip[:, None], 0.0, p)
+
         # dV += P^T @ dO
         dv += tl.dot(tl.trans(p.to(do_tile.dtype)), do_tile)
         # dS = P * (dO @ V^T - delta), dK += dS^T @ Q
```

```diff
@@ -700,6 +767,7 @@ def forward(
         sparsity_m,
         num_sink_tokens,
         dense_window_size,
+        skip_softmax_threshold,
     ):
         HEAD_DIM = q.shape[2]
         num_q_heads = q.shape[1]
```

```diff
@@ -720,6 +788,17 @@ def forward(
         # Triton tiles must be powers of 2; pad head dim
         BLOCK_D = triton.next_power_of_2(HEAD_DIM)

+        # Skip-softmax: convert threshold to scaled log2 space for the kernel.
+        # The BLASST reference (https://arxiv.org/pdf/2512.12087) checks
+        # ln(lambda) on unscaled scores. Our kernel works in log2-scaled space
+        # (scores pre-multiplied by qk_scale = sm_scale * LOG2E), so we
+        # pre-scale: threshold_scaled = log2(lambda) * sm_scale.
+        apply_skip = skip_softmax_threshold is not None and skip_softmax_threshold > 0.0
+        if apply_skip:
+            skip_threshold_log2 = math.log2(skip_softmax_threshold) * sm_scale
+        else:
+            skip_threshold_log2 = 0.0
+
         o = torch.empty_like(q)
         lse = torch.empty(q.shape[0], num_q_heads, device=q.device, dtype=torch.float32)
```

```diff
@@ -758,6 +837,8 @@ def grid(META):
             SPARSITY_M=sparsity_m,
             NUM_SINK_TOKENS=num_sink_tokens,
             DENSE_WINDOW_SIZE=dense_window_size,
+            APPLY_SKIP_SOFTMAX=apply_skip,
+            SKIP_THRESHOLD_LOG2=skip_threshold_log2,
             # BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune
         )
```

```diff
@@ -776,6 +857,8 @@ def grid(META):
         ctx.sparsity_m = sparsity_m
         ctx.num_sink_tokens = num_sink_tokens
         ctx.dense_window_size = dense_window_size
+        ctx.apply_skip = apply_skip
+        ctx.skip_threshold_log2 = skip_threshold_log2
         return o

     @staticmethod
```

```diff
@@ -854,6 +937,8 @@ def backward(ctx, grad_output):
             SPARSITY_M=ctx.sparsity_m,
             NUM_SINK_TOKENS=ctx.num_sink_tokens,
             DENSE_WINDOW_SIZE=ctx.dense_window_size,
+            APPLY_SKIP_SOFTMAX=ctx.apply_skip,
+            SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
             num_warps=num_warps,
             num_stages=1,
         )
```

```diff
@@ -877,11 +962,30 @@ def backward(ctx, grad_output):
             SPARSITY_M=ctx.sparsity_m,
             NUM_SINK_TOKENS=ctx.num_sink_tokens,
             DENSE_WINDOW_SIZE=ctx.dense_window_size,
+            APPLY_SKIP_SOFTMAX=ctx.apply_skip,
+            SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
             num_warps=num_warps,
             num_stages=1,
         )

-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
+        return (
+            dq,
+            dk,
+            dv,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )


 def attention(
```

```diff
@@ -901,8 +1005,9 @@ def attention(
     sparsity_m: int = 4,
     num_sink_tokens: int = 0,
     dense_window_size: int = 64,
+    skip_softmax_threshold: float | None = None,
 ) -> torch.Tensor:
-    """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax.
+    """Variable-length flash attention with GQA, autograd, and optional N:M sparse softmax and skip-softmax.

     Args:
         q: [total_q_tokens, num_q_heads, head_dim]
```

```diff
@@ -926,6 +1031,12 @@ def attention(
         dense_window_size: Tokens near the query diagonal kept dense (local
             attention window). Absolute token count, BLOCK_N-independent.
             Default 64 (one reference block).
+        skip_softmax_threshold: BLASST threshold lambda
+            (https://arxiv.org/pdf/2512.12087). Skip KV tiles where
+            ``exp(tile_max - running_max) < lambda``, meaning the tile's
+            softmax contribution is negligible. Tiles are skipped entirely
+            (no softmax, V load, or BMM2). The threshold is applied on
+            unscaled scores. Set to ``None`` or ``0`` to disable.

     Returns:
         Output tensor [total_q_tokens, num_q_heads, head_dim].
```

```diff
@@ -947,6 +1058,7 @@ def attention(
         sparsity_m,
         num_sink_tokens,
         dense_window_size,
+        skip_softmax_threshold,
     )
```
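The forward-kernel change can be sanity-checked against a plain NumPy online softmax. The sketch below (illustrative only, not ModelOpt code; `online_softmax_av` and its tiling are assumptions) applies the BLASST skip test in the same log2-scaled space as the kernel, with the threshold pre-scaled as `log2(lambda) * sm_scale`:

```python
import math

import numpy as np

def online_softmax_av(scores, v, lam=None, sm_scale=1.0, block=4):
    """One query row: online softmax over KV tiles of size `block`, optionally
    skipping tiles whose scaled max is below running_max + log2(lam) * sm_scale."""
    log2e = math.log2(math.e)
    s = scores * sm_scale * log2e                     # kernel works on log2-scaled scores
    thr = math.log2(lam) * sm_scale if lam else None  # SKIP_THRESHOLD_LOG2
    m, l, acc = -np.inf, 0.0, np.zeros(v.shape[1])
    for start in range(0, len(s), block):
        tile = s[start:start + block]
        tmax = tile.max()
        # Skip test: equivalent to exp(raw_tile_max - raw_running_max) < lam
        if thr is not None and tmax < m + thr:
            continue                                  # skip: no exp, no V load, no BMM2
        m_new = max(m, tmax)
        p = np.exp2(tile - m_new)                     # unnormalized weights
        corr = np.exp2(m - m_new) if m > -np.inf else 0.0
        l = l * corr + p.sum()                        # rescale running sum
        acc = acc * corr + p @ v[start:start + block]
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
scores = rng.normal(size=16)
v = rng.normal(size=(16, 8))

dense = online_softmax_av(scores, v)              # no skipping
skipped = online_softmax_av(scores, v, lam=1e-6)  # tiny lambda: near-dense output
```

With a small lambda the skipped result stays numerically close to the dense one, which mirrors the intent of the new `tests` entry in the PR summary (fidelity vs dense).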
modelopt/torch/sparsity/attention_sparsity/config.py

Lines changed: 25 additions & 0 deletions

```diff
@@ -129,6 +129,16 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
         ),
     )

+    skip_softmax_threshold: float = ModeloptField(
+        default=0.1,
+        title="Skip-softmax threshold.",
+        description=(
+            "Tiles contributing less than this fraction are skipped entirely. "
+            "Only used by triton_skip_softmax. Typical values: 1e-3 to 1e-1. "
+            "Set to 0 to disable."
+        ),
+    )
+
     @field_validator("method")
     @classmethod
     def validate_method(cls, v):
```

```diff
@@ -528,9 +538,24 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
     }


+# Default skip-softmax configuration for Triton kernel
+SKIP_SOFTMAX_TRITON_DEFAULT = {
+    "sparse_cfg": {
+        "*attn*": {
+            "method": "triton_skip_softmax",
+            "skip_softmax_threshold": 0.1,
+            "backend": "triton",
+            "enable": True,
+        },
+        "default": {"enable": False},
+    },
+}
+
+
 __all__ = [
     "SKIP_SOFTMAX_CALIB",
     "SKIP_SOFTMAX_DEFAULT",
+    "SKIP_SOFTMAX_TRITON_DEFAULT",
     "SPARSE_SOFTMAX_DEFAULT",
     "CalibrationConfig",
     "FlashSkipSoftmaxConfig",
```
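Since `SKIP_SOFTMAX_TRITON_DEFAULT` is a plain nested dict, a custom threshold can be set by deep-copying it and overriding one key. A minimal sketch (the dict is re-declared here from the diff so the example is self-contained; in real code it would come from `modelopt.torch.sparsity.attention_sparsity`):

```python
import copy

# Re-declared from the config.py diff for a self-contained example.
SKIP_SOFTMAX_TRITON_DEFAULT = {
    "sparse_cfg": {
        "*attn*": {
            "method": "triton_skip_softmax",
            "skip_softmax_threshold": 0.1,
            "backend": "triton",
            "enable": True,
        },
        "default": {"enable": False},
    },
}

# A tighter threshold: fewer skipped tiles, numerics closer to dense attention.
cfg = copy.deepcopy(SKIP_SOFTMAX_TRITON_DEFAULT)
cfg["sparse_cfg"]["*attn*"]["skip_softmax_threshold"] = 0.01

# deepcopy leaves the shared default untouched.
print(SKIP_SOFTMAX_TRITON_DEFAULT["sparse_cfg"]["*attn*"]["skip_softmax_threshold"])  # 0.1
```

The customized `cfg` would then be passed to `mtsa.sparsify(model, cfg)` as in the usage section above.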

modelopt/torch/sparsity/attention_sparsity/methods/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -24,4 +24,4 @@
 ]

 # Import method implementations to trigger registration
-from . import flash_skip_softmax, triton_sparse_softmax
+from . import flash_skip_softmax, triton_skip_softmax, triton_sparse_softmax
```
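The one-line change above works because registration happens as an import side effect. A minimal sketch of that pattern (hypothetical names such as `METHOD_REGISTRY` and `register_method`, not ModelOpt's actual registry):

```python
# Import-triggered registration pattern (illustrative, not ModelOpt's registry).
METHOD_REGISTRY: dict[str, type] = {}

def register_method(name):
    """Class decorator: record the class under `name` when its module is imported."""
    def deco(cls):
        METHOD_REGISTRY[name] = cls
        return cls
    return deco

# In a real package each class lives in its own module, and the package
# __init__ imports those modules purely for this side effect — which is why
# adding triton_skip_softmax to the import line is all the registration needed.
@register_method("triton_sparse_softmax")
class TritonSparseSoftmax: ...

@register_method("triton_skip_softmax")
class TritonSkipSoftmax: ...

print(sorted(METHOD_REGISTRY))
```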

modelopt/torch/sparsity/attention_sparsity/methods/registry.py

Lines changed: 1 addition & 4 deletions

```diff
@@ -76,10 +76,7 @@ def apply_sparsity(
         Returns:
             Masked attention scores with sparse elements set to -inf
         """
-        raise NotImplementedError(
-            f"{type(self).__name__} does not implement apply_sparsity. "
-            "Sparsity may be fused into the kernel (Triton backend)."
-        )
+        raise NotImplementedError(f"{type(self).__name__} does not implement apply_sparsity.")

     def get_sparse_context(self, module: torch.nn.Module):
         """Return a context manager that activates this method's sparsity during forward.
```

0 commit comments

Comments
 (0)