Skip to content

Commit 9a7ae2a

Browse files
committed
Address review feedbacks
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 5b22b85 commit 9a7ae2a

23 files changed

Lines changed: 226 additions & 773 deletions

File tree

examples/llm_sparsity/attention_sparsity/README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Attention Sparsity for HuggingFace Models
22

3-
In this tutorial, we demonstrate how to use NVIDIA TensorRT Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation.
3+
In this tutorial, we demonstrate how to use NVIDIA Model Optimizer to apply attention sparsity to HuggingFace models. Attention sparsity reduces computational cost by skipping near-zero attention scores during the softmax computation.
44

55
## Getting Started
66

@@ -63,7 +63,7 @@ pip install nvidia-modelopt[hf]
6363
If using `SKIP_SOFTMAX_CALIB`, you need to download the RULER calibration dataset first:
6464

6565
```bash
66-
bash modelopt/torch/sparsity/attention_sparsity/calibration/download_ruler_data.sh
66+
bash ./download_ruler_data.sh
6767
```
6868

6969
This downloads the Paul Graham essays dataset used for generating calibration samples.
@@ -75,7 +75,7 @@ This downloads the Paul Graham essays dataset used for generating calibration sa
7575
Apply sparse attention with a fixed threshold:
7676

7777
```bash
78-
python examples/llm_sparsity/attention_sparsity/hf_sa.py \
78+
python hf_sa.py \
7979
--pyt_ckpt_path Qwen/Qwen3-8B \
8080
--sparse_attn skip_softmax
8181
```
@@ -85,7 +85,7 @@ python examples/llm_sparsity/attention_sparsity/hf_sa.py \
8585
Apply sparse attention with calibrated thresholds for optimal sparsity:
8686

8787
```bash
88-
python examples/llm_sparsity/attention_sparsity/hf_sa.py \
88+
python hf_sa.py \
8989
--pyt_ckpt_path Qwen/Qwen3-8B \
9090
--sparse_attn skip_softmax_calib
9191
```
@@ -121,7 +121,7 @@ The script automatically compares outputs before and after applying sparse atten
121121
Export the sparsified model to a HuggingFace checkpoint:
122122

123123
```bash
124-
python examples/llm_sparsity/attention_sparsity/hf_sa.py \
124+
python hf_sa.py \
125125
--pyt_ckpt_path Qwen/Qwen3-8B \
126126
--sparse_attn skip_softmax_calib \
127127
--export_dir ./exported_sparse_model
@@ -161,5 +161,5 @@ model = mtsa.sparsify(model, config=custom_config)
161161

162162
## References
163163

164-
- [TensorRT Model Optimizer Documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/)
164+
- [Model Optimizer Documentation](https://nvidia.github.io/Model-Optimizer/)
165165
- [RULER: What's the Real Context Size of Your Long-Context Language Models?](https://github.com/NVIDIA/RULER)

examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def main(args):
171171
print(f"\nApplying sparse attention: {args.sparse_attn}")
172172
sparse_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]
173173

174-
# Override target_sparse_ratio if provided via CLI
174+
# Override calibration options if provided via CLI
175175
if args.target_sparse_ratio is not None:
176176
sparse_config = copy.deepcopy(sparse_config)
177177
sparse_cfg = sparse_config.get("sparse_cfg", {})

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Lines changed: 16 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515

1616
"""Calibration functions for sparse attention."""
1717

18-
import hashlib
19-
import json
2018
import warnings
2119
from collections.abc import Callable
22-
from pathlib import Path
2320
from typing import Any
2421

2522
import torch
@@ -28,59 +25,11 @@
2825

2926
from ..config import CalibrationConfig
3027
from ..conversion import print_sparse_attention_summary
31-
from ..sparse_attention import SparseAttentionModule
28+
from ..utils import get_named_sparse_attention_modules
3229
from .calibrator import DynamicThresholdCalibrator
3330
from .dataset import RulerDatasetBuilder
3431

3532

36-
def _get_cache_path(
37-
tokenizer_path: str, samples: int, max_seqlen: int, cache_dir: str | None = None
38-
) -> Path:
39-
"""Generate cache file path based on calibration parameters.
40-
41-
Args:
42-
tokenizer_path: Path to tokenizer (used in hash)
43-
samples: Number of calibration samples
44-
max_seqlen: Maximum sequence length
45-
cache_dir: Optional cache directory. If None, uses ~/.cache/modelopt/sparse_attention/
46-
"""
47-
# Create a hash of the parameters for the cache filename
48-
key = f"{tokenizer_path}_{samples}_{max_seqlen}"
49-
hash_str = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()[:12]
50-
filename = f"ruler_cache_{samples}s_{max_seqlen}l_{hash_str}.json"
51-
52-
if cache_dir:
53-
base_dir = Path(cache_dir)
54-
else:
55-
base_dir = Path.home() / ".cache" / "modelopt" / "sparse_attention"
56-
57-
return base_dir / filename
58-
59-
60-
def _load_cached_data(cache_path: Path) -> list[dict[str, Any]] | None:
61-
"""Load calibration data from cache if it exists."""
62-
if cache_path.exists():
63-
try:
64-
with open(cache_path) as f:
65-
data = json.load(f)
66-
print(f"Loaded {len(data)} cached calibration samples from {cache_path}")
67-
return data
68-
except Exception as e:
69-
print(f"Warning: Failed to load cache: {e}")
70-
return None
71-
72-
73-
def _save_cached_data(cache_path: Path, data: list[dict[str, Any]]) -> None:
74-
"""Save calibration data to cache."""
75-
try:
76-
cache_path.parent.mkdir(parents=True, exist_ok=True)
77-
with open(cache_path, "w") as f:
78-
json.dump(data, f)
79-
print(f"Saved calibration samples to cache: {cache_path}")
80-
except Exception as e:
81-
print(f"Warning: Failed to save cache: {e}")
82-
83-
8433
def _extract_tokenizer_from_model(model: nn.Module) -> str:
8534
"""Extract tokenizer name/path from model config.
8635
@@ -152,7 +101,9 @@ def create_calibration_forward_loop(
152101
tokenizer.pad_token = tokenizer.eos_token
153102

154103
def forward_loop(model: nn.Module) -> None:
155-
device = next(model.parameters()).device
104+
from modelopt.torch.utils import get_module_device
105+
106+
device = get_module_device(model)
156107

157108
for sample in calibration_data:
158109
inputs = tokenizer(
@@ -210,7 +161,9 @@ def create_decode_calibration_forward_loop(
210161
tokenizer.pad_token = tokenizer.eos_token
211162

212163
def forward_loop(model: nn.Module) -> None:
213-
device = next(model.parameters()).device
164+
from modelopt.torch.utils import get_module_device
165+
166+
device = get_module_device(model)
214167

215168
for sample in calibration_data:
216169
inputs = tokenizer(
@@ -291,9 +244,7 @@ def calibrate_sparse_attention(
291244
return {}
292245

293246
# Get sparse attention modules
294-
sparse_modules = [
295-
(name, m) for name, m in model.named_modules() if isinstance(m, SparseAttentionModule)
296-
]
247+
sparse_modules = get_named_sparse_attention_modules(model)
297248

298249
if not sparse_modules:
299250
print("No sparse attention modules found for calibration")
@@ -306,29 +257,16 @@ def calibrate_sparse_attention(
306257
calibration_data = None
307258

308259
if calibrate_prefill or calibrate_decode:
309-
# Try to load from cache first
310-
cache_path = _get_cache_path(
311-
tokenizer,
312-
calib_config.samples,
313-
calib_config.max_seqlen,
260+
builder = RulerDatasetBuilder(
261+
samples=calib_config.samples,
262+
max_seqlen=calib_config.max_seqlen,
263+
tokenizer_name_or_path=tokenizer,
264+
num_length_bins=calib_config.num_length_bins,
265+
max_length_filter=int(calib_config.max_seqlen * 1.5),
314266
cache_dir=calib_config.cache_dir,
267+
data_dir=calib_config.data_dir,
315268
)
316-
calibration_data = _load_cached_data(cache_path)
317-
318-
# Generate if not cached
319-
if calibration_data is None:
320-
builder = RulerDatasetBuilder(
321-
samples=calib_config.samples,
322-
max_seqlen=calib_config.max_seqlen,
323-
tokenizer_name_or_path=tokenizer,
324-
num_length_bins=calib_config.num_length_bins,
325-
max_length_filter=int(calib_config.max_seqlen * 1.5),
326-
)
327-
calibration_data = builder.build_calibration_dataset()
328-
print(f"Generated {len(calibration_data)} calibration samples")
329-
330-
# Save to cache for future runs
331-
_save_cached_data(cache_path, calibration_data)
269+
calibration_data = builder.build_calibration_dataset()
332270

333271
# Initialize results
334272
calibration_results: dict[str, Any] = {}

modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
from scipy.optimize import curve_fit
2727
from tqdm import tqdm
2828

29-
from ..sparse_attention import SparseAttentionModule
3029
from ..stats_manager import SparseAttentionStatsManager
30+
from ..utils import get_sparse_attention_modules
3131

3232

3333
class DynamicThresholdCalibrator:
@@ -113,7 +113,7 @@ def calibrate(self, model: nn.Module, forward_loop: Callable, phase: str) -> dic
113113
Dict with calibration results including a, b, r_squared, and num_data_points
114114
"""
115115
# Extract attention modules
116-
attention_modules = [m for m in model.modules() if isinstance(m, SparseAttentionModule)]
116+
attention_modules = get_sparse_attention_modules(model)
117117

118118
if not attention_modules:
119119
raise ValueError("No sparse attention modules found for calibration")

modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515

1616
"""RULER dataset builder for sparse attention calibration."""
1717

18+
import hashlib
19+
import json
1820
import random
1921
import string
2022
from dataclasses import dataclass
23+
from pathlib import Path
2124
from typing import Any
2225

2326
from tqdm import tqdm
@@ -125,7 +128,7 @@ class RulerTask:
125128
),
126129
answer_prefix=(
127130
" Answer: According to the chain(s) of variable assignment in the text above, "
128-
"{num_v} variables are assgined the value {query}, they are: "
131+
"{num_v} variables are assigned the value {query}, they are: "
129132
),
130133
args={"num_chains": 1, "num_hops": 4},
131134
),
@@ -189,6 +192,8 @@ def __init__(
189192
num_length_bins: int = 4,
190193
max_length_filter: int = 65536,
191194
seed: int = 42,
195+
cache_dir: str | None = None,
196+
data_dir: str | Path | None = None,
192197
):
193198
"""Initialize RULER dataset builder.
194199
@@ -199,6 +204,9 @@ def __init__(
199204
seed: Random seed for reproducibility
200205
num_length_bins: Number of length bins to generate (default: 4)
201206
max_length_filter: Maximum sequence length to keep (default: 65536)
207+
cache_dir: Optional cache directory. If None, uses ~/.cache/modelopt/data/
208+
data_dir: Optional path to RULER data directory (contains 'essays' subdir).
209+
Required for NIAH tasks with essay haystack when not using pip default layout.
202210
203211
Note:
204212
Length bins are auto-generated as descending powers of 2:
@@ -220,6 +228,8 @@ def __init__(
220228
self.tokenizer_name_or_path = tokenizer_name_or_path
221229
self.seed = seed
222230
self.max_length_filter = max_length_filter
231+
self.cache_dir = cache_dir
232+
self.data_dir = Path(data_dir) if data_dir is not None else None
223233

224234
# Generate target lengths and validate
225235
self.target_lengths = _generate_target_lengths(max_seqlen, num_length_bins, min_seqlen=1024)
@@ -238,12 +248,58 @@ def __init__(
238248
self.tokenizer = tokenizer_name_or_path
239249
random.seed(seed)
240250

251+
def _get_cache_path(self) -> Path:
252+
"""Generate cache file path based on calibration parameters."""
253+
tokenizer_path = (
254+
self.tokenizer_name_or_path
255+
if isinstance(self.tokenizer_name_or_path, str)
256+
else str(self.tokenizer_name_or_path)
257+
)
258+
key = f"{tokenizer_path}_{self.total_samples}_{self.max_seqlen}"
259+
hash_str = hashlib.md5(key.encode(), usedforsecurity=False).hexdigest()[:12]
260+
filename = f"ruler_cache_{self.total_samples}s_{self.max_seqlen}l_{hash_str}.json"
261+
if self.cache_dir:
262+
base_dir = Path(self.cache_dir)
263+
else:
264+
base_dir = Path.home() / ".cache" / "modelopt" / "data"
265+
return base_dir / filename
266+
267+
def _load_cached_data(self, cache_path: Path) -> list[dict[str, Any]] | None:
268+
"""Load calibration data from cache if it exists."""
269+
if cache_path.exists():
270+
try:
271+
with open(cache_path) as f:
272+
data = json.load(f)
273+
print(f"Loaded {len(data)} cached calibration samples from {cache_path}")
274+
return data
275+
except Exception as e:
276+
print(f"Warning: Failed to load cache: {e}")
277+
return None
278+
279+
def _save_cached_data(self, cache_path: Path, data: list[dict[str, Any]]) -> None:
280+
"""Save calibration data to cache."""
281+
try:
282+
cache_path.parent.mkdir(parents=True, exist_ok=True)
283+
with open(cache_path, "w") as f:
284+
json.dump(data, f)
285+
print(f"Saved calibration samples to cache: {cache_path}")
286+
except Exception as e:
287+
print(f"Warning: Failed to save cache: {e}")
288+
241289
def build_calibration_dataset(self) -> list[dict[str, Any]]:
242290
"""Build the complete calibration dataset.
243291
292+
If cache_dir was set, checks cache first and returns cached data if present.
293+
Otherwise generates the dataset, saves to cache (if cache_dir set), and returns.
294+
244295
Returns:
245296
List of calibration samples with 'input' and 'length' fields
246297
"""
298+
cache_path = self._get_cache_path()
299+
cached = self._load_cached_data(cache_path)
300+
if cached is not None:
301+
return cached
302+
247303
all_samples = []
248304

249305
print(
@@ -265,6 +321,8 @@ def build_calibration_dataset(self) -> list[dict[str, Any]]:
265321

266322
random.shuffle(all_samples)
267323
print(f"Generated {len(all_samples)} valid samples")
324+
325+
self._save_cached_data(cache_path, all_samples)
268326
return all_samples
269327

270328
def _generate_sample(
@@ -312,6 +370,7 @@ def _generate_niah_sample(
312370
num_needle_k=args.get("num_needle_k", 1),
313371
num_needle_v=args.get("num_needle_v", 1),
314372
num_needle_q=args.get("num_needle_q", 1),
373+
data_dir=self.data_dir,
315374
)
316375

317376
# Generate sample using official RULER implementation
@@ -328,6 +387,7 @@ def _generate_niah_sample(
328387
num_needle_v=args.get("num_needle_v", 1),
329388
num_needle_q=args.get("num_needle_q", 1),
330389
random_seed=self.seed + sample_idx,
390+
data_dir=self.data_dir,
331391
)
332392

333393
# Add task metadata

0 commit comments

Comments
 (0)