Commit 7c33d85

[1/n] Add a Triton attention kernel with HF integration (#1034)
### What does this PR do?

**Type of change:** ?

- Adds a Triton flash attention kernel (`triton_fa.py`) with HF integration for use in sparse attention and quantization workflows. The kernel implements Flash Attention with varlen support, GQA, causal masking, and forward/backward passes.
- Updates sparse attention to support `backend="triton"`.

Key components:

- `modelopt/torch/kernels/triton_fa.py` -- core Triton kernel
- `modelopt/torch/kernels/hf_triton_attention.py` -- HF adapter, registered as `attn_implementation="modelopt_triton"`
- `modelopt/torch/kernels/__init__.py` -- shared kernel registry
- `modelopt/torch/sparsity/attention_sparsity/conversion.py` -- backend selection (`backend="triton"` or `"pytorch"`)
- `modelopt/torch/sparsity/attention_sparsity/config.py` -- added `"triton"` as a valid backend option
- `examples/llm_sparsity/attention_sparsity/hf_sa.py` -- updated example to support `--backend triton`

### Usage

```python
# Direct kernel API (varlen packed format)
from modelopt.torch.kernels import attention

o = attention(
    q, k, v,                  # [total_tokens, heads, head_dim]
    b_start_loc=b_start_loc,  # [batch] per-sequence start offsets
    b_seq_len=b_seq_len,      # [batch] per-sequence lengths
    max_input_len=max_seq_len,
    is_causal=True,
)

# HuggingFace integration (automatic via sparsify)
import modelopt.torch.sparsity.attention_sparsity as mtsa

config = {
    "sparse_cfg": {
        "*attn*": {"method": "flash_skip_softmax", "backend": "triton", "enable": True}
    }
}
model = mtsa.sparsify(model, config=config)
# model now uses the Triton kernel for attention

# Or load directly with attn_implementation
model = AutoModelForCausalLM.from_pretrained(path, attn_implementation="modelopt_triton")
```

### Testing

`tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py`

#### Kernel benchmark on RTX 6000

| SEQ_LEN | ModelOpt Triton | PyTorch SDPA | Flash Attention 2 |
|--------:|----------------:|-------------:|------------------:|
| 256 | 34.435353 | 26.199215 | 47.927293 |
| 512 | 60.216998 | 47.408218 | 80.736116 |
| 1024 | 81.209990 | 82.673526 | 94.197181 |
| 2048 | 88.800973 | 89.239451 | 94.822496 |
| 4096 | 88.302953 | 89.192071 | 96.826178 |
| 8192 | 89.538177 | 89.115835 | 91.563461 |
| 16384 | 85.457533 | 80.509254 | 81.391092 |

### Before your PR is "*Ready for review*"

Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and that your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅ / ❌ / N/A
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: ✅ / ❌ / N/A
- Did you write any new necessary tests?: ✅ / ❌ / N/A
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added a Triton flash attention backend for sparse attention operations, alongside the existing PyTorch backend, enabling improved performance on compatible hardware.
* **Documentation**
  * Updated the README to document both available attention backends and their configurations.
* **Tests**
  * Added comprehensive test coverage for the new Triton backend, including forward and backward pass validation.

Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent: `00fa5bd` · commit: `7c33d85`

File tree

14 files changed (+1498, −95 lines)


`examples/llm_sparsity/attention_sparsity/README.md` (4 additions, 1 deletion)

```diff
@@ -1,6 +1,9 @@
 # Attention Sparsity for HuggingFace Models
 
-In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation.
+In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation. Two attention backends are supported:
+
+- **pytorch** (default): Patches `F.softmax` to apply skip-softmax sparsity (requires `attn_implementation="eager"`)
+- **triton**: Uses a fused Triton Flash Attention kernel with in-kernel sparsity (uses `attn_implementation="modelopt_triton"`)
 
 ## Getting Started
 
```
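The README's backend choice maps one-to-one onto an HF `attn_implementation` string. A minimal sketch of that mapping, assuming the correspondence documented above (the dict and helper names below are hypothetical illustrations, not modelopt API):

```python
# Hypothetical mapping from the example's --backend flag to the HF
# attn_implementation each backend needs, per the README above.
BACKEND_TO_ATTN_IMPL = {
    "pytorch": "eager",           # softmax patching only works with eager attention
    "triton": "modelopt_triton",  # fused Triton kernel registered with HF
}


def attn_impl_for(backend: str) -> str:
    """Return the attn_implementation string for a sparse-attention backend."""
    if backend not in BACKEND_TO_ATTN_IMPL:
        raise ValueError(f"unknown backend: {backend!r}")
    return BACKEND_TO_ATTN_IMPL[backend]
```

A caller would pass the returned string to `AutoModelForCausalLM.from_pretrained(path, attn_implementation=...)`.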

`examples/llm_sparsity/attention_sparsity/hf_sa.py` (18 additions, 19 deletions)

```diff
@@ -144,9 +144,8 @@ def main(args):
 
     print(f"Loading model: {args.pyt_ckpt_path}")
 
-    # Load model and tokenizer
-    # Note: attn_implementation="eager" is required for calibration to work properly
-    # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection)
+    # No need to specify attn_implementation here — mtsa.sparsify() sets it
+    # automatically ("eager" for pytorch backend, "modelopt_triton" for triton).
     model = AutoModelForCausalLM.from_pretrained(
         args.pyt_ckpt_path,
         attn_implementation="eager",
@@ -164,21 +163,21 @@ def main(args):
     output_before, test_prompt, input_ids = generate_sample_output(model, tokenizer, args)
 
     # Apply sparse attention with optional calibration
-    print(f"\nApplying sparse attention: {args.sparse_attn}")
-    sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]
-
-    # Override calibration options if provided via CLI
+    print(f"\nApplying sparse attention: {args.sparse_attn} (backend={args.backend})")
+    sparse_config = copy.deepcopy(SPARSE_ATTN_CFG_CHOICES[args.sparse_attn])
+
+    # Apply CLI overrides to sparse_cfg
+    sparse_cfg = sparse_config.get("sparse_cfg", {})
+    for layer_cfg in sparse_cfg.values():
+        if isinstance(layer_cfg, dict) and "method" in layer_cfg:
+            layer_cfg["backend"] = args.backend
     if args.target_sparse_ratio is not None:
-        sparse_config = copy.deepcopy(sparse_config)
-        sparse_cfg = sparse_config.get("sparse_cfg", {})
-        if isinstance(sparse_cfg, dict) and "calibration" in sparse_cfg:
-            calibration_cfg = sparse_cfg["calibration"]
-            if isinstance(calibration_cfg, dict):
-                calibration_cfg["target_sparse_ratio"] = {
-                    "prefill": args.target_sparse_ratio,
-                    "decode": args.target_sparse_ratio,
-                }
-                print(f"Overriding target_sparse_ratio to {args.target_sparse_ratio}")
+        calib = sparse_cfg.setdefault("calibration", {})
+        assert isinstance(calib, dict)
+        calib["target_sparse_ratio"] = {
+            "prefill": args.target_sparse_ratio,
+            "decode": args.target_sparse_ratio,
+        }
 
     model = mtsa.sparsify(model, config=sparse_config)
     print("Sparse attention applied successfully!")
@@ -242,8 +241,8 @@ def main(args):
         "--backend",
         type=str,
         default="pytorch",
-        choices=["pytorch"],
-        help="Backend for sparse attention (default: pytorch). More backends coming soon.",
+        choices=["pytorch", "triton"],
+        help="Backend for sparse attention (default: pytorch). 'triton' uses the fused Triton kernel.",
     )
 
     # Sequence length arguments
```
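The override flow in this diff can be sketched in isolation: deep-copy the preset so CLI flags never mutate the shared dict, stamp the backend into every per-layer method config, then upsert the calibration target. A stand-alone sketch of that pattern (the `PRESET` dict is a simplified stand-in for one `SPARSE_ATTN_CFG_CHOICES` entry, not the real config):

```python
import copy

# Simplified stand-in for one entry of SPARSE_ATTN_CFG_CHOICES.
PRESET = {
    "sparse_cfg": {
        "*attn*": {"method": "flash_skip_softmax", "enable": True},
    }
}


def apply_cli_overrides(base_config, backend, target_sparse_ratio=None):
    """Mirror hf_sa.py's CLI override flow on a plain config dict."""
    config = copy.deepcopy(base_config)  # keep the shared preset pristine
    sparse_cfg = config.get("sparse_cfg", {})
    # Stamp the backend into every per-layer method config.
    for layer_cfg in sparse_cfg.values():
        if isinstance(layer_cfg, dict) and "method" in layer_cfg:
            layer_cfg["backend"] = backend
    # Upsert the calibration target for both prefill and decode.
    if target_sparse_ratio is not None:
        calib = sparse_cfg.setdefault("calibration", {})
        calib["target_sparse_ratio"] = {
            "prefill": target_sparse_ratio,
            "decode": target_sparse_ratio,
        }
    return config
```

The unconditional `deepcopy` is the key change over the old code, which only copied when `--target_sparse_ratio` was given and would otherwise mutate the module-level preset when setting the backend.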

`modelopt/torch/kernels/__init__.py` (new file, 47 additions)

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared Triton kernels for modelopt (attention, quantization, etc.)."""

import torch

from modelopt.torch.utils import import_plugin

IS_AVAILABLE = False
attention = None
register_triton_attention = None

if torch.cuda.is_available():
    with import_plugin(
        "triton",
        msg_if_missing=(
            "Your device is potentially capable of using the triton attention "
            "kernel. Try to install triton with `pip install triton`."
        ),
    ):
        from .triton_fa import attention as _attention

        attention = _attention
        IS_AVAILABLE = True
    with import_plugin("transformers"):
        from .hf_triton_attention import register_triton_attention as _register_triton_attention

        register_triton_attention = _register_triton_attention

__all__ = [
    "IS_AVAILABLE",
    "attention",
    "register_triton_attention",
]
```
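The registry's guarded-import shape reduces to a generic optional-dependency pattern: exported names default to `None`, and `IS_AVAILABLE` flips only when the optional import succeeds. A simplified sketch with a plain availability probe in place of modelopt's `import_plugin` context manager (the `require_attention` helper is illustrative, not part of the package):

```python
import importlib.util

IS_AVAILABLE = False
attention = None

# Probe for the optional dependency without importing the real kernel module
# (stands in for the import_plugin("triton") context manager above).
if importlib.util.find_spec("triton") is not None:
    # In the real registry this is where `from .triton_fa import attention`
    # runs and `attention` is bound; here we only flip the flag.
    IS_AVAILABLE = True


def require_attention():
    """Fail loudly at call time, not import time, when the kernel is missing."""
    if not IS_AVAILABLE or attention is None:
        raise RuntimeError(
            "Triton attention kernel unavailable; install it with `pip install triton`."
        )
    return attention
```

This lets downstream code do `from modelopt.torch.kernels import attention` unconditionally and defer the failure until the kernel is actually needed.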
`modelopt/torch/kernels/hf_triton_attention.py` (new file, 143 additions)

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HuggingFace attention backend using the Triton flash attention kernel.

Registers as attn_implementation="modelopt_triton" so HF models dispatch to the
Triton kernel natively. Handles format conversion between HF's [batch, heads, seq, dim]
and the kernel's flat packed [total_tokens, heads, dim] varlen format.
"""

from __future__ import annotations

import torch
import torch.nn as nn

from modelopt.torch.kernels.triton_fa import attention


def _seq_lens_from_mask(
    attention_mask: torch.Tensor | None,
    fallback: int,
    device: torch.device,
) -> tuple[torch.Tensor | None, bool]:
    """Derive per-sequence lengths from attention mask.

    Returns (b_seq_len, has_padding). If the mask is not a usable 2D format,
    returns (None, False).
    """
    if attention_mask is not None and attention_mask.dim() == 2:
        mask = attention_mask.bool() if attention_mask.dtype != torch.bool else attention_mask
        b_seq_len = mask.sum(dim=1).to(torch.int32).to(device)
        has_padding = bool((b_seq_len != fallback).any())
        return b_seq_len, has_padding
    return None, False


def triton_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
) -> tuple[torch.Tensor, None]:
    """Attention forward compatible with HF AttentionInterface.

    Converts HF tensors to varlen format, calls the Triton kernel, converts back.
    Handles both prefill (seq_len > 1) and decode (seq_len == 1).

    Args:
        module: The attention module (LlamaAttention etc.).
        query: [batch, num_heads, seq_len, head_dim].
        key: [batch, num_kv_heads, seq_k, head_dim].
        value: [batch, num_kv_heads, seq_k, head_dim].
        attention_mask: Optional; kernel handles causal masking internally.
            2D [batch, seq_len] masks are used to derive per-sequence lengths.
            Other formats (e.g. 4D causal masks) are ignored.
        scaling: Softmax scale (e.g. 1/sqrt(head_dim)).
        dropout: Ignored (kernel has no dropout); use 0 for eval.
        **kwargs: Reserved for future extensions.

    Returns:
        (attn_output, None) with attn_output [batch, seq_len, num_heads, head_dim].
    """
    batch, num_heads, seq_len, head_dim = query.shape
    seq_k = key.shape[2]
    num_kv_heads = key.shape[1]
    device = query.device
    is_decode = seq_len <= 1

    # Reshape from HF [batch, heads, seq, dim] -> flat [batch*seq, heads, dim]
    q = query.permute(0, 2, 1, 3).reshape(batch * seq_len, num_heads, head_dim).contiguous()
    k = key.permute(0, 2, 1, 3).reshape(batch * seq_k, num_kv_heads, head_dim).contiguous()
    v = value.permute(0, 2, 1, 3).reshape(batch * seq_k, num_kv_heads, head_dim).contiguous()

    # Build varlen metadata
    b_seq_len_q, has_padding = _seq_lens_from_mask(attention_mask, seq_len, device)
    if b_seq_len_q is None:
        b_seq_len_q = torch.full((batch,), seq_len, device=device, dtype=torch.int32)

    kw = {
        "b_start_loc": torch.arange(batch, device=device, dtype=torch.int32) * seq_len,
        "b_seq_len": b_seq_len_q,
        "max_input_len": seq_len,
        "is_causal": not is_decode,
        "softmax_scale": scaling,
    }
    # Decode: Q has 1 token, K/V have seq_k tokens (KV cache, no padding)
    if is_decode:
        kw["b_start_loc_k"] = torch.arange(batch, device=device, dtype=torch.int32) * seq_k
        kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
        kw["max_input_len_k"] = seq_k

    o = attention(q, k, v, **kw)

    attn_output = o.view(batch, seq_len, num_heads, head_dim)

    # Zero out padding positions (kernel produces NaN for all-padding rows due to 0/0).
    # Assumes right-padding (valid tokens at positions 0..n-1), which is the HF
    # convention during prefill. Left-padded inputs are not supported.
    if has_padding:
        pad_mask = torch.arange(seq_len, device=device)[None, :] >= b_seq_len_q[:, None]
        attn_output = attn_output.masked_fill(pad_mask[:, :, None, None], 0.0)

    return (attn_output, None)


def register_triton_attention() -> bool:
    """Register the Triton backend with HF AttentionInterface.

    Called by _set_attn_implementation() during sparsification. Must run before
    the model's first forward pass so HF dispatches to the Triton kernel.

    Returns:
        True if registration succeeded, False if transformers API not available.
    """
    try:
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
    except (ImportError, AttributeError):
        return False

    ALL_ATTENTION_FUNCTIONS.register("modelopt_triton", triton_attention_forward)
    return True


__all__ = [
    "register_triton_attention",
    "triton_attention_forward",
]
```
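The adapter's varlen bookkeeping is plain arithmetic on the 2D padding mask: per-sequence lengths are row sums, start offsets are multiples of the padded `seq_len` (the flat tensor keeps the padded layout), and `has_padding` flags whether any row falls short. A torch-free sketch of that math, assuming right-padded 0/1 masks (the function name is illustrative):

```python
def varlen_metadata(mask_rows, seq_len):
    """Derive (b_start_loc, b_seq_len, has_padding) from a right-padded 0/1 mask.

    Mirrors _seq_lens_from_mask plus the b_start_loc construction in
    triton_attention_forward, using plain lists instead of tensors.
    """
    # Valid-token count per sequence = row sum of the mask.
    b_seq_len = [sum(row) for row in mask_rows]
    # Flat start offsets: sequence i begins at i * seq_len in the packed tensor.
    b_start_loc = [i * seq_len for i in range(len(mask_rows))]
    # Any row shorter than seq_len means padded positions exist.
    has_padding = any(n != seq_len for n in b_seq_len)
    return b_start_loc, b_seq_len, has_padding
```

For a batch of two sequences padded to length 4 with 3 and 4 valid tokens, this yields start offsets `[0, 4]`, lengths `[3, 4]`, and `has_padding=True`; positions at index `>= length` in each row are exactly the ones the adapter later zeroes out.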
