Skip to content

Commit 05da7ea

Browse files
committed
Add 2:4 sparse attention kernel
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent d78797b commit 05da7ea

14 files changed

Lines changed: 2561 additions & 62 deletions

File tree

examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from modelopt.torch.sparsity.attention_sparsity.config import (
3232
SKIP_SOFTMAX_CALIB,
3333
SKIP_SOFTMAX_DEFAULT,
34+
SPARSE24_TRITON,
3435
)
3536
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
3637

@@ -43,6 +44,7 @@
4344
SPARSE_ATTN_CFG_CHOICES = {
4445
"skip_softmax": SKIP_SOFTMAX_DEFAULT,
4546
"skip_softmax_calib": SKIP_SOFTMAX_CALIB,
47+
"sparse24_triton": SPARSE24_TRITON,
4648
}
4749

4850

@@ -144,12 +146,14 @@ def main(args):
144146

145147
print(f"Loading model: {args.pyt_ckpt_path}")
146148

147-
# Load model and tokenizer
148-
# Note: attn_implementation="eager" is required for calibration to work properly
149-
# (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection)
149+
# Select attn_implementation based on sparse method:
150+
# - skip_softmax methods require "eager" (softmax patching bypassed by flash/sdpa)
151+
# - sparse24_triton requires "modelopt_triton" (fused Triton kernel)
152+
# No need to specify attn_implementation here — mtsa.sparsify() handles it
153+
# automatically based on the sparse config (sets "modelopt_triton" for triton
154+
# backend, keeps "eager" for pytorch backend).
150155
model = AutoModelForCausalLM.from_pretrained(
151156
args.pyt_ckpt_path,
152-
attn_implementation="eager",
153157
torch_dtype=torch.bfloat16,
154158
)
155159
tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
@@ -246,8 +250,8 @@ def main(args):
246250
"--backend",
247251
type=str,
248252
default="pytorch",
249-
choices=["pytorch"],
250-
help="Backend for sparse attention (default: pytorch). More backends coming soon.",
253+
choices=["pytorch", "triton"],
254+
help="Backend for sparse attention (default: pytorch). Use 'triton' with sparse24_triton.",
251255
)
252256

253257
# Sequence length arguments

modelopt/torch/sparsity/attention_sparsity/config.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
7272
title="Backend implementation.",
7373
description=(
7474
"Backend to use for sparse attention computation. "
75-
"Only 'pytorch' is supported, which uses softmax patching with F.softmax. "
76-
"Requires model to be loaded with attn_implementation='eager'."
75+
"'pytorch' uses softmax patching with F.softmax (requires attn_implementation='eager'). "
76+
"'triton' uses the fused Triton kernel (requires attn_implementation='modelopt_triton')."
7777
),
7878
)
7979

@@ -89,10 +89,20 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
8989
description=(
9090
"Whether the model uses causal (autoregressive) attention. "
9191
"If True, sparsity statistics are calculated over the lower triangle only. "
92+
"Set to False for cross-attention models. "
9293
"Defaults to True for decoder-only models like GPT, LLaMA, etc."
9394
),
9495
)
9596

97+
skip_diagonal_blocks: bool = ModeloptField(
98+
default=True,
99+
title="Skip diagonal blocks.",
100+
description=(
101+
"When True, keep diagonal tiles dense for 2:4 sparse attention. "
102+
"Only used by sparse24_triton method. Defaults to True."
103+
),
104+
)
105+
96106
@field_validator("method")
97107
@classmethod
98108
def validate_method(cls, v):
@@ -104,11 +114,12 @@ def validate_method(cls, v):
104114
@field_validator("backend")
105115
@classmethod
106116
def validate_backend(cls, v):
107-
"""Validate backend is pytorch."""
108-
if v != "pytorch":
117+
"""Validate backend is pytorch or triton."""
118+
if v not in ("pytorch", "triton"):
109119
raise ValueError(
110-
f"Invalid backend: {v}. Only 'pytorch' backend is supported. "
111-
f"Model must be loaded with attn_implementation='eager'."
120+
f"Invalid backend: {v}. Supported backends: 'pytorch' (requires "
121+
f"attn_implementation='eager'), 'triton' (requires "
122+
f"attn_implementation='modelopt_triton')."
112123
)
113124
return v
114125

@@ -416,10 +427,24 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
416427
},
417428
}
418429

430+
# 2:4 structured sparsity via Triton prefill kernel (prefill-only)
431+
SPARSE24_TRITON = {
432+
"sparse_cfg": {
433+
"*attn*": {
434+
"method": "sparse24_triton",
435+
"backend": "triton",
436+
"skip_diagonal_blocks": True,
437+
"enable": True,
438+
},
439+
"default": {"enable": False},
440+
},
441+
}
442+
419443

420444
__all__ = [
421445
"SKIP_SOFTMAX_CALIB",
422446
"SKIP_SOFTMAX_DEFAULT",
447+
"SPARSE24_TRITON",
423448
"CalibrationConfig",
424449
"FlashSkipSoftmaxConfig",
425450
"SparseAttentionAttributeConfig",

modelopt/torch/sparsity/attention_sparsity/conversion.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,37 @@
3232
from .utils import get_named_sparse_attention_modules, get_sparse_attention_modules
3333

3434

35+
def _register_triton_backend_if_needed(model: nn.Module, config: SparseAttentionConfig) -> None:
36+
"""Register the Triton attention backend and set attn_implementation if needed.
37+
38+
When the config uses ``backend="triton"``, this function:
39+
1. Registers the Triton kernel with HF's ``ALL_ATTENTION_FUNCTIONS``.
40+
2. Sets ``model.config._attn_implementation = "modelopt_triton"`` so the
41+
model dispatches to the Triton kernel at forward time.
42+
43+
This is called automatically during ``mtsa.sparsify()`` so users never need
44+
to manually call ``register_triton_attention()`` or set ``attn_implementation``.
45+
"""
46+
sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {}
47+
needs_triton = any(
48+
isinstance(v, dict) and v.get("backend") == "triton" for v in sparse_cfg.values()
49+
)
50+
if not needs_triton:
51+
return
52+
53+
from .kernels import register_triton_attention
54+
55+
if register_triton_attention is not None:
56+
register_triton_attention()
57+
58+
# Set attn_implementation on the model so HF dispatches to the Triton kernel.
59+
# HF's ALL_ATTENTION_FUNCTIONS is checked at forward time, not construction time,
60+
# so this works even after the model is already loaded.
61+
model_config = getattr(model, "config", None)
62+
if model_config is not None:
63+
model_config._attn_implementation = "modelopt_triton"
64+
65+
3566
def is_attn_sparsified(model: nn.Module) -> bool:
3667
"""Check if a model has sparse attention applied.
3768
@@ -61,6 +92,9 @@ def convert_to_sparse_attention_model(
6192
# Initialize the true module if necessary
6293
model = model.init_modellike() if isinstance(model, ModelLikeModule) else model
6394

95+
# Register Triton attention backend and set attn_implementation if needed
96+
_register_triton_backend_if_needed(model, config)
97+
6498
# Apply custom model plugins
6599
register_custom_model_plugins_on_the_fly(model)
66100

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Triton attention kernels for sparse attention optimization."""
17+
18+
import torch
19+
20+
from modelopt.torch.utils import import_plugin
21+
22+
IS_AVAILABLE = False
23+
context_attention_fwd = None
24+
context_attention = None
25+
register_triton_attention = None
26+
set_sparse24 = None
27+
28+
if torch.cuda.is_available():
29+
with import_plugin(
30+
"triton",
31+
msg_if_missing=(
32+
"Your device is potentially capable of using the triton attention "
33+
"kernel. Try to install triton with `pip install triton`."
34+
),
35+
):
36+
from .triton_unified_attention import context_attention as _context_attention
37+
from .triton_unified_attention import context_attention_fwd as _context_attention_fwd
38+
39+
context_attention_fwd = _context_attention_fwd
40+
context_attention = _context_attention
41+
IS_AVAILABLE = True
42+
with import_plugin("transformers"):
43+
from .hf_triton_attention import register_triton_attention as _register_triton_attention
44+
from .hf_triton_attention import set_sparse24 as _set_sparse24
45+
46+
register_triton_attention = _register_triton_attention
47+
set_sparse24 = _set_sparse24
48+
_register_triton_attention()
49+
50+
__all__ = [
51+
"IS_AVAILABLE",
52+
"context_attention",
53+
"context_attention_fwd",
54+
"register_triton_attention",
55+
"set_sparse24",
56+
]

0 commit comments

Comments
 (0)