Commit 8b01ba4
Skip softmax calibration via Triton kernel (#1597)
### What does this PR do?
Adds skip softmax calibration for LLMs via Triton kernel (leveraging
existing kernel used for diffusion)
Type of change: New feature
<!-- Details about the change. -->
### Usage
```
python hf_sa.py --pyt_ckpt_path Qwen/Qwen3-8B --sparse_attn skip_softmax_triton_calib
```
The Triton calibration equals PyTorch at every threshold, for both
phases:
| threshold | prefill triton/pytorch | decode triton/pytorch |
|------|------------------------|-----------------------|
| 0.30 | 0.0% / 0.0% | 12.5% / 12.5% |
| 0.50 | 0.0% / 0.0% | 37.5% / 37.5% |
| 0.70 | 10.0% / 10.0% | 62.5% / 62.5% |
### Before your PR is "*Ready for review*"
Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)
and your commits are signed (`git commit -s -S`).
Make sure you read and follow the [Security Best
Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors)
(e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(...,
weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ✅ / ❌ / N/A <!--- If ❌, explain
why. -->
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A
<!--- Mandatory -->
- Did you write any new necessary tests?: ✅ / ❌ / N/A <!--- Mandatory
for new features or examples. -->
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
✅ / ❌ / N/A <!--- Only for new features, API changes, critical bug fixes
or backward incompatible changes. -->
- Did you get Claude approval on this PR?: ✅ / ❌ / N/A <!--- Run
`/claude review`. NVIDIA org members can self-trigger for complex
changes; orthogonal to CodeRabbit. -->
### Additional Information
<!-- E.g. related issue. -->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Added a Triton-based skip-softmax sparse-attention calibration option
and a CLI flag to override the calibration data directory (defaults to
adjacent RULER data).
* **Bug Fixes**
* Ensure calibration kernels run on the correct CUDA device; align
measurement granularity and tile/block sizing; ignore padded query rows
when counting skippable tiles.
* **Tests**
* Added GPU Triton calibration tests for end-to-end inference,
multi-threshold stats, and decode-phase reporting.
* **Documentation**
* Updated changelog and example to expose the new option and flag.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
Signed-off-by: Kai Xu <kaix@nvidia.com>
Co-authored-by: Kai Xu <kaix@nvidia.com>1 parent e2c3e03 commit 8b01ba4
11 files changed
Lines changed: 553 additions & 140 deletions
File tree
- examples/llm_sparsity/attention_sparsity
- modelopt/torch
- kernels
- common/attention
- sparsity/attention
- sparsity/attention_sparsity
- calibration
- methods
- tests/gpu/torch
- kernels/sparsity/attention
- sparsity/attention_sparsity
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
61 | 61 | | |
62 | 62 | | |
63 | 63 | | |
| 64 | + | |
| 65 | + | |
64 | 66 | | |
65 | 67 | | |
66 | 68 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
31 | 31 | | |
32 | 32 | | |
33 | 33 | | |
| 34 | + | |
34 | 35 | | |
35 | 36 | | |
36 | 37 | | |
| |||
44 | 45 | | |
45 | 46 | | |
46 | 47 | | |
| 48 | + | |
47 | 49 | | |
48 | 50 | | |
49 | 51 | | |
| |||
186 | 188 | | |
187 | 189 | | |
188 | 190 | | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
189 | 200 | | |
190 | 201 | | |
191 | 202 | | |
| |||
302 | 313 | | |
303 | 314 | | |
304 | 315 | | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
305 | 324 | | |
306 | 325 | | |
307 | 326 | | |
Lines changed: 37 additions & 4 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
27 | 27 | | |
28 | 28 | | |
29 | 29 | | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
30 | 34 | | |
31 | 35 | | |
32 | 36 | | |
| |||
105 | 109 | | |
106 | 110 | | |
107 | 111 | | |
108 | | - | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
109 | 115 | | |
110 | 116 | | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
111 | 141 | | |
112 | 142 | | |
113 | 143 | | |
114 | 144 | | |
115 | 145 | | |
116 | 146 | | |
117 | 147 | | |
118 | | - | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
119 | 151 | | |
120 | | - | |
121 | | - | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
122 | 155 | | |
123 | 156 | | |
124 | 157 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
80 | 80 | | |
81 | 81 | | |
82 | 82 | | |
83 | | - | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
84 | 87 | | |
85 | 88 | | |
86 | 89 | | |
| |||
363 | 366 | | |
364 | 367 | | |
365 | 368 | | |
| 369 | + | |
| 370 | + | |
366 | 371 | | |
367 | 372 | | |
368 | 373 | | |
| |||
919 | 924 | | |
920 | 925 | | |
921 | 926 | | |
922 | | - | |
923 | | - | |
924 | | - | |
925 | | - | |
926 | | - | |
927 | | - | |
928 | | - | |
929 | | - | |
930 | | - | |
931 | | - | |
932 | | - | |
933 | | - | |
934 | | - | |
935 | | - | |
936 | | - | |
937 | | - | |
938 | | - | |
| 927 | + | |
| 928 | + | |
| 929 | + | |
| 930 | + | |
| 931 | + | |
| 932 | + | |
| 933 | + | |
| 934 | + | |
| 935 | + | |
| 936 | + | |
| 937 | + | |
| 938 | + | |
| 939 | + | |
| 940 | + | |
| 941 | + | |
| 942 | + | |
| 943 | + | |
| 944 | + | |
| 945 | + | |
| 946 | + | |
| 947 | + | |
| 948 | + | |
| 949 | + | |
939 | 950 | | |
940 | 951 | | |
941 | 952 | | |
| |||
970 | 981 | | |
971 | 982 | | |
972 | 983 | | |
| 984 | + | |
| 985 | + | |
| 986 | + | |
| 987 | + | |
| 988 | + | |
| 989 | + | |
973 | 990 | | |
974 | 991 | | |
975 | | - | |
976 | | - | |
977 | | - | |
978 | | - | |
979 | | - | |
980 | | - | |
981 | | - | |
982 | | - | |
983 | | - | |
984 | | - | |
985 | | - | |
986 | | - | |
987 | | - | |
988 | | - | |
989 | | - | |
| 992 | + | |
| 993 | + | |
| 994 | + | |
| 995 | + | |
| 996 | + | |
| 997 | + | |
| 998 | + | |
| 999 | + | |
| 1000 | + | |
| 1001 | + | |
| 1002 | + | |
| 1003 | + | |
| 1004 | + | |
| 1005 | + | |
| 1006 | + | |
| 1007 | + | |
990 | 1008 | | |
991 | 1009 | | |
992 | 1010 | | |
| |||
1016 | 1034 | | |
1017 | 1035 | | |
1018 | 1036 | | |
1019 | | - | |
1020 | | - | |
1021 | | - | |
1022 | | - | |
1023 | | - | |
1024 | | - | |
1025 | | - | |
1026 | | - | |
1027 | | - | |
1028 | | - | |
1029 | | - | |
1030 | | - | |
1031 | | - | |
1032 | | - | |
1033 | | - | |
1034 | | - | |
1035 | | - | |
1036 | | - | |
1037 | | - | |
1038 | | - | |
1039 | | - | |
1040 | | - | |
1041 | | - | |
1042 | | - | |
1043 | | - | |
1044 | | - | |
| 1037 | + | |
| 1038 | + | |
| 1039 | + | |
| 1040 | + | |
| 1041 | + | |
| 1042 | + | |
| 1043 | + | |
| 1044 | + | |
| 1045 | + | |
| 1046 | + | |
| 1047 | + | |
| 1048 | + | |
| 1049 | + | |
| 1050 | + | |
| 1051 | + | |
| 1052 | + | |
| 1053 | + | |
| 1054 | + | |
| 1055 | + | |
| 1056 | + | |
| 1057 | + | |
| 1058 | + | |
| 1059 | + | |
| 1060 | + | |
| 1061 | + | |
| 1062 | + | |
| 1063 | + | |
1045 | 1064 | | |
1046 | 1065 | | |
1047 | | - | |
1048 | | - | |
1049 | | - | |
1050 | | - | |
1051 | | - | |
1052 | | - | |
1053 | | - | |
1054 | | - | |
1055 | | - | |
1056 | | - | |
1057 | | - | |
1058 | | - | |
1059 | | - | |
1060 | | - | |
1061 | | - | |
1062 | | - | |
1063 | | - | |
1064 | | - | |
1065 | | - | |
1066 | | - | |
1067 | | - | |
1068 | | - | |
1069 | | - | |
| 1066 | + | |
| 1067 | + | |
| 1068 | + | |
| 1069 | + | |
| 1070 | + | |
| 1071 | + | |
| 1072 | + | |
| 1073 | + | |
| 1074 | + | |
| 1075 | + | |
| 1076 | + | |
| 1077 | + | |
| 1078 | + | |
| 1079 | + | |
| 1080 | + | |
| 1081 | + | |
| 1082 | + | |
| 1083 | + | |
| 1084 | + | |
| 1085 | + | |
| 1086 | + | |
| 1087 | + | |
| 1088 | + | |
| 1089 | + | |
1070 | 1090 | | |
1071 | 1091 | | |
1072 | 1092 | | |
| |||
0 commit comments