Commit 9e38041

[OMNIML-2850] [3/n] Adds sparse attention calibration (#538)
## What does this PR do?

**Type of change:** new feature

**Overview:**

- Adds the sparse attention calibration algorithm
- Chunked prefill to support long ctx_len
- Separate calibration for prefill and decode

## Usage

```python
import modelopt.torch.sparsity.attention_sparsity as mtsa

# Apply sparse attention with calibration
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB)

# Print summary - now shows actual thresholds
mtsa.print_sparse_attention_summary(model)
# Output:
# Method: flash_skip_softmax, Threshold: Dynamic (λ=437.395926)
```

Or via the llm_eval integration:

```bash
# HuggingFace sparse attention example
python examples/llm_sparsity/attention_sparsity/hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-4B \
    --sparse_attn skip_softmax_calib
```

## The Calibration Method

### Calibration Algorithm

- Implemented the inverse power model: `scale_factor = k / (1 - sparsity)^p`
- Fit the model parameters `(k, p)` per phase using `scipy.optimize.curve_fit`
- At inference: `threshold = k / (1 - target_sparsity)^p / seqlen`

### Why Choose the Inverse Power Model?

The inverse power model better fits the relationship between the sparsity ratio and the threshold scale factor.

<img width="2388" height="1082" alt="sparsity_model_analysis" src="https://github.com/user-attachments/assets/4dfb45d4-8c16-4f15-a878-c8e08a9b6128" />

### Runtime Flexibility

- The target sparsity can be changed at inference time without recalibration
- Users can adjust `module._sparse_method_instance.target_sparse_ratio` dynamically
- The threshold automatically adapts to sequence length

## Testing

The calibration results for `Qwen/Qwen3-30B-A3B-Thinking-2507` are shown below and are mostly consistent with the ground-truth numbers collected from the kernel side.
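The fit-then-predict scheme described above can be sketched as follows. This is a minimal standalone illustration, not the library's implementation: the calibration measurements are made up, and only the model form and `scipy.optimize.curve_fit` are taken from the description.

```python
import numpy as np
from scipy.optimize import curve_fit


def inverse_power(sparsity, k, p):
    # scale_factor = k / (1 - sparsity)^p
    return k / (1.0 - sparsity) ** p


# Hypothetical calibration measurements: (sparsity ratio, observed threshold scale factor)
sparsity = np.array([0.3, 0.5, 0.7, 0.8, 0.9])
scale = np.array([1600.0, 2400.0, 4600.0, 7600.0, 18200.0])

# Fit (k, p) once per phase (prefill/decode) during calibration
(k, p), _ = curve_fit(inverse_power, sparsity, scale, p0=(1000.0, 1.0))


def threshold(target_sparsity, seqlen):
    # At inference the threshold adapts to sequence length, and
    # target_sparsity can be changed without recalibration.
    return k / (1.0 - target_sparsity) ** p / seqlen


print(f"k={k:.1f}, p={p:.2f}")
print(f"threshold(0.9, 8192)={threshold(0.9, 8192):.4f}")
```

A higher sparsity target or a shorter sequence both raise the skip threshold, which matches the formula in the algorithm summary.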
```
Prefill Calibration Results:
  Model: scale_factor = k / (1 - sparsity)^p
  Fitted k: 1003.3990
  Fitted p: 1.2589
  R-squared: 0.827549

Scale factors for different target sparsities:
    Target    Scale Factor
----------  --------------
       50%         2401.35
       70%         4568.26
       80%         7610.98
       90%        18214.70
       95%        43591.65
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

---

Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 3801923 commit 9e38041

38 files changed

Lines changed: 4740 additions & 423 deletions

CHANGELOG.rst

Lines changed: 7 additions & 0 deletions
```diff
@@ -1,6 +1,13 @@
 NVIDIA Model Optimizer Changelog (Linux)
 ========================================
 
+0.43 (2026-03-xx)
+^^^^^^^^^^^^^^^^^
+
+**New Features**
+
+- Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
+
 0.42 (2026-02-xx)
 ^^^^^^^^^^^^^^^^^
```

examples/llm_eval/lm_eval_hf.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -43,9 +43,11 @@
 from lm_eval.api.model import T
 from lm_eval.models.huggingface import HFLM
 from quantization_utils import quantize_model
+from sparse_attention_utils import sparsify_model
 
 import modelopt.torch.opt as mto
 from modelopt.torch.quantization.utils import is_quantized
+from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified
 
 
 def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T:
@@ -60,6 +62,9 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
     calib_size = arg_dict.pop("calib_size", 512)
     compress = arg_dict.pop("compress", False)
 
+    # Sparse attention arguments
+    sparse_cfg = arg_dict.pop("sparse_cfg", None)
+
     additional_config = {} if additional_config is None else additional_config
     additional_config = {k: v for k, v in additional_config.items() if v is not None}
 
@@ -91,6 +96,15 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
         auto_quantize_checkpoint=auto_quantize_checkpoint,
     )
 
+    if sparse_cfg:
+        if is_attn_sparsified(model_obj.model):
+            warnings.warn("Skipping sparse attention: model already has sparse attention applied.")
+        else:
+            sparsify_model(
+                model=model_obj,
+                sparse_cfg=sparse_cfg,
+            )
+
     return model_obj
 
 
@@ -152,6 +166,11 @@ def setup_parser_with_modelopt_args():
         action="store_true",
         help="Compress the model after quantization",
     )
+    parser.add_argument(
+        "--sparse_cfg",
+        type=str,
+        help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)",
+    )
     return parser
 
 
@@ -177,6 +196,7 @@ def setup_parser_with_modelopt_args():
             "calib_batch_size": args.calib_batch_size,
             "calib_size": args.calib_size,
             "compress": args.compress,
+            "sparse_cfg": args.sparse_cfg,
         }
     )
```

examples/llm_eval/mmlu.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -48,6 +48,7 @@
 from fire import Fire
 from modeling import EvalModel, select_model
 from quantization_utils import MAX_SEQ_LEN, get_tokenizer, quantize_model
+from sparse_attention_utils import sparsify_model
 from tqdm import tqdm
 
 try:
@@ -56,6 +57,7 @@
     LLM = None  # type: ignore[misc]
 import modelopt.torch.opt as mto
 from modelopt.torch.quantization.utils import is_quantized
+from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -230,6 +232,7 @@ def main(
     auto_quantize_method: str = "gradient",
     auto_quantize_score_size: int = 128,
     auto_quantize_checkpoint: str | None = None,
+    sparse_cfg: str | None = None,
     **kwargs,
 ):
     random.seed(RAND_SEED)
@@ -289,6 +292,20 @@
         auto_quantize_checkpoint=auto_quantize_checkpoint,
     )
 
+    # Apply sparse attention if requested
+    if sparse_cfg:
+        model.load()
+
+        if is_attn_sparsified(model.model):
+            warnings.warn(
+                "Skipping sparse attention: model already has sparse attention applied."
+            )
+        else:
+            sparsify_model(
+                model=model,
+                sparse_cfg=sparse_cfg,
+            )
+
     for subject in tqdm(subjects):
         dev_df = pd.read_csv(os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None)[
             :ntrain
```

examples/llm_eval/modeling.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -179,6 +179,7 @@ class SeqToSeqModel(EvalModel):
     lora_path: str = ""
     device: str = "cuda"
     load_8bit: bool = False
+    attn_implementation: str | None = None
 
     def load(self):
         if self.model is None:
@@ -188,6 +189,8 @@ def load(self):
         if self.load_8bit:
             args.update(device_map="auto", load_in_8bit=True)
         args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
+        if self.attn_implementation:
+            args["attn_implementation"] = self.attn_implementation
         self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path, **args)
         print_gpu_utilization()
         if self.lora_path:
@@ -241,6 +244,8 @@ def load(self):
         if self.load_8bit:
             args.update(device_map="auto", load_in_8bit=True)
         args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
+        if self.attn_implementation:
+            args["attn_implementation"] = self.attn_implementation
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_path, trust_remote_code=True, **args
         )
```
examples/llm_eval/sparse_attention_utils.py

Lines changed: 78 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for sparse attention integration with llm_eval."""

import modelopt.torch.sparsity.attention_sparsity as mtsa


def _extract_model(model_obj):
    """Extract actual model from wrapper (HFLM or EvalModel)."""
    if hasattr(model_obj, "gpt2"):
        return model_obj.gpt2
    elif hasattr(model_obj, "model"):
        return model_obj.model
    else:
        return model_obj


def sparsify_model(
    model,
    sparse_cfg: str,
    backend=None,
):
    """Apply sparse attention to model with optional RULER calibration.

    Args:
        model: Model wrapper (HFLM or EvalModel) or raw model
        sparse_cfg: Sparse attention config name or dict
        backend: Backend to use (optional, overrides config backend)

    Returns:
        The model with sparse attention applied

    Note:
        Calibration is automatically triggered if the config contains a 'calibration' field.
        The calibration will auto-generate RULER dataset from the model's tokenizer.
    """
    # Extract actual model
    net = _extract_model(model)

    # Resolve config
    if isinstance(sparse_cfg, str):
        # Get config from mtsa module (e.g., SKIP_SOFTMAX_CALIB, SKIP_SOFTMAX_DEFAULT)
        mtsa_cfg = getattr(mtsa, sparse_cfg, None)
        if mtsa_cfg is None:
            raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}.")
    else:
        mtsa_cfg = sparse_cfg

    # Override backend if specified
    if backend:
        if isinstance(mtsa_cfg, dict) and "sparse_cfg" in mtsa_cfg:
            modified_sparse_cfg = {}
            for pattern, cfg in mtsa_cfg["sparse_cfg"].items():
                modified_cfg = cfg.copy() if isinstance(cfg, dict) else cfg
                if isinstance(modified_cfg, dict):
                    modified_cfg["backend"] = backend
                modified_sparse_cfg[pattern] = modified_cfg
            mtsa_cfg = {"sparse_cfg": modified_sparse_cfg}

    # Apply sparsification
    print(f"\nApplying sparse attention with config: {sparse_cfg}")
    mtsa.sparsify(net, mtsa_cfg)
    print("Sparse attention applied successfully!")

    return model
```
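The backend override in `sparsify_model` above is a plain dict merge. In isolation it behaves like this standalone sketch (the helper name `override_backend` and the `"triton"` backend value are ours, purely for illustration):

```python
def override_backend(cfg: dict, backend: str) -> dict:
    """Return a copy of a {'sparse_cfg': {pattern: {...}}} config with every backend replaced."""
    modified = {}
    for pattern, sub in cfg["sparse_cfg"].items():
        sub = sub.copy() if isinstance(sub, dict) else sub  # shallow-copy so the input is untouched
        if isinstance(sub, dict):
            sub["backend"] = backend
        modified[pattern] = sub
    return {"sparse_cfg": modified}


cfg = {
    "sparse_cfg": {
        "*attn*": {"method": "flash_skip_softmax", "backend": "pytorch"},
        "default": {"enable": False},
    }
}
new_cfg = override_backend(cfg, "triton")
print(new_cfg["sparse_cfg"]["*attn*"]["backend"])  # the override applies to every per-pattern dict
```

Note the override is applied to every per-pattern sub-dict, including `default`, while the original config object is left unchanged.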
Lines changed: 2 additions & 0 deletions
```diff
@@ -0,0 +1,2 @@
+# Data directory for calibration
+data
```
examples/llm_sparsity/attention_sparsity/README.md

Lines changed: 165 additions & 0 deletions

# Attention Sparsity for HuggingFace Models

In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation.

## Getting Started

### Quick Example

```python
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

# Load your model
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    attn_implementation="eager",  # Required for sparse attention
    torch_dtype=torch.bfloat16,
)

# Apply sparse attention
model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)
```

> [!Note]
> `attn_implementation="eager"` is required for sparse attention to work properly. Flash Attention 2 and SDPA bypass the softmax patching needed for stats collection.
## Configuration Options

Two pre-defined configurations are available:

### 1. Fixed Threshold (SKIP_SOFTMAX_DEFAULT)

Uses a fixed threshold value. Simple, but may not be optimal for all sequence lengths.

```python
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_DEFAULT

model = mtsa.sparsify(model, config=SKIP_SOFTMAX_DEFAULT)
```

### 2. Calibrated Threshold (SKIP_SOFTMAX_CALIB)

Uses RULER-based calibration to determine an optimal dynamic threshold that adapts to sequence length. Recommended for production use.

```python
from modelopt.torch.sparsity.attention_sparsity.config import SKIP_SOFTMAX_CALIB

model = mtsa.sparsify(model, config=SKIP_SOFTMAX_CALIB)
```

## Prerequisites

### Local Installation

For Hugging Face models, install Model Optimizer with `hf` dependencies using `pip` from [PyPI](https://pypi.org/project/nvidia-modelopt/) and install the requirements for the example:

```bash
pip install nvidia-modelopt[hf]
```

### Download RULER Calibration Data (Required for Calibration)

If you use `SKIP_SOFTMAX_CALIB`, download the RULER calibration dataset first:

```bash
bash ./download_ruler_data.sh
```

This downloads the Paul Graham essays dataset used for generating calibration samples.

## Run Sparse Attention on HuggingFace Models

### Basic Usage (Without Calibration)

Apply sparse attention with a fixed threshold:

```bash
python hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --sparse_attn skip_softmax
```

### With RULER Calibration

Apply sparse attention with calibrated thresholds for optimal sparsity:

```bash
python hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --sparse_attn skip_softmax_calib
```
The calibration process:

1. Generates RULER calibration samples
2. Collects attention statistics during forward passes
3. Determines optimal threshold scale factor for target sparsity ratio
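Steps 2 and 3 above amount to measuring, for each candidate threshold, what fraction of attention probabilities would be skipped. A standalone sketch with simulated post-softmax scores (illustrative only; the real statistics come from the model's forward passes):

```python
import numpy as np

rng = np.random.default_rng(0)
# Simulated post-softmax attention probabilities for one query row (seqlen 4096)
probs = rng.dirichlet(np.full(4096, 0.05))

# Sweep trial thresholds and record the sparsity each one would induce
threshold_trials = [1e-5, 1e-4, 1e-3, 1e-2]
sparsities = []
for t in threshold_trials:
    s = float((probs < t).mean())  # fraction of entries that would be skipped
    sparsities.append(s)
    print(f"threshold={t:g} -> sparsity={s:.3f}")
```

Each (threshold, sparsity) pair becomes one calibration data point; fitting the inverse power model to these points yields the scale factor used at inference.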
### Command Line Arguments

| Argument | Default | Description |
|----------|---------|-------------|
| `--pyt_ckpt_path` | Required | HuggingFace model path or name |
| `--sparse_attn` | `skip_softmax` | Configuration: `skip_softmax` or `skip_softmax_calib` |
| `--backend` | `pytorch` | Backend: `pytorch` (only supported backend) |
| `--seq_len` | `2048` | Maximum sequence length for input prompts |
| `--export_dir` | `None` | Directory to export the sparsified model |
## Output Comparison

The script automatically compares outputs before and after applying sparse attention:

1. Loads a test sample from the NarrativeQA dataset
2. Generates text before sparse attention is applied
3. Applies sparse attention (with optional calibration)
4. Generates text after sparse attention is applied
5. Compares and displays both outputs

## Export Model

Export the sparsified model to a HuggingFace checkpoint:

```bash
python hf_sa.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --sparse_attn skip_softmax_calib \
    --export_dir ./exported_sparse_model
```

The exported model can be loaded and used with standard HuggingFace APIs.

## Custom Configuration

You can create custom sparse attention configurations:

```python
custom_config = {
    "sparse_cfg": {
        "calibration": {  # Optional: omit for a fixed threshold
            "target_sparse_ratio": {"prefill": 0.5, "decode": 0.5},  # Target 50% sparsity
            "samples": 128,  # Number of calibration samples
            "max_seqlen": 8192,  # Maximum sequence length
            # Optional: customize threshold trials for calibration
            "threshold_trials": [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1],
        },
        "*attn*": {  # Pattern to match attention modules
            "method": "flash_skip_softmax",
            "threshold": {"prefill": 1e-3, "decode": 1e-4},  # Phase-specific thresholds (ignored if calibration is used)
            "br": 128,  # Flash Attention block rows
            "bc": 128,  # Flash Attention block columns
            "backend": "pytorch",
            "collect_stats": True,
            "enable": True,
        },
        "default": {"enable": False},
    },
}

model = mtsa.sparsify(model, config=custom_config)
```

## References

- [Model Optimizer Documentation](https://nvidia.github.io/Model-Optimizer/)
- [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER)
