Commit 2bad66c
[3/n] Add skip-softmax to Triton flash attention kernel (#1081)
### What does this PR do?
Type of change: New feature

Adds skip-softmax tile skipping to the Triton flash attention kernel: key/value tiles whose softmax contribution falls below a configurable threshold are skipped.
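The keep/skip decision happens per key tile inside the online-softmax loop: a tile whose maximum exponentiated score is far below the running row maximum contributes negligible softmax mass and can be dropped. A minimal NumPy sketch of that idea (a toy single-head illustration, not the actual Triton kernel; the `tile` size and the exact keep test are assumptions):

```python
import numpy as np

def attention_skip_softmax(q, k, v, tile=4, threshold=0.1):
    """Toy single-head attention that skips low-contribution K/V tiles.

    Illustrative skip criterion: a tile with exp(tile_max - row_max)
    below `threshold` adds negligible softmax mass for that query row.
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    # Online-softmax running state, one entry per query row.
    row_max = np.full(n, -np.inf)
    row_sum = np.zeros(n)
    acc = np.zeros((n, v.shape[1]))
    for start in range(0, k.shape[0], tile):
        s = q @ k[start:start + tile].T * scale       # tile scores
        tile_max = s.max(axis=1)
        new_max = np.maximum(row_max, tile_max)
        # Rows where this tile's best score is far below the running max
        # contribute ~nothing to the softmax and are skipped.
        keep = np.exp(tile_max - new_max) >= threshold
        p = np.exp(s - new_max[:, None])
        p[~keep] = 0.0                                # drop skipped rows
        correction = np.exp(row_max - new_max)        # rescale old state
        row_sum = row_sum * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ v[start:start + tile]
        row_max = new_max
    return acc / row_sum[:, None]
```

With `threshold=0.0` no tile is ever skipped and the result matches dense softmax attention exactly, which makes the sketch easy to sanity-check.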
### Usage
```python
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.kernels import attention

# Call the kernel directly: with skip_softmax_threshold=0.1, tiles
# contributing < 10% of the softmax mass are skipped.
out = attention(
    q, k, v, b_start_loc, b_seq_len, max_len, skip_softmax_threshold=0.1
)

# Or apply it via mtsa.sparsify() on a HuggingFace model.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16, device_map="cuda"
)
# Default config enabling the Triton skip-softmax method
mtsa.sparsify(model, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT)
```
### Testing
Performance in TFLOPS on an RTX 6000 Pro, across sequence lengths:
| SEQ_LEN | ModelOpt Triton | PyTorch SDPA | Flash Attention 2 | Skip-Softmax t=0.01 | Skip-Softmax t=0.1 |
|---:|---:|---:|---:|---:|---:|
| 16384 | 188.85 | 211.72 | 224.24 | 172.90 | 279.86 |
| 32768 | 175.32 | 212.82 | 224.83 | 146.15 | 262.49 |
| 65536 | 167.30 | 214.93 | 226.46 | 145.08 | 243.34 |
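For reference, TFLOPS figures in tables like this are conventionally derived from the standard attention FLOP count: two matmuls (QK^T and PV) of 2·seq_len²·head_dim FLOPs each, per head. A small helper sketching that convention (hypothetical names; the benchmark's exact batch/head configuration is not stated in this PR):

```python
def attention_tflops(batch, heads, seq_len, head_dim, seconds, causal=False):
    """Estimate achieved TFLOPS for one attention forward pass."""
    # QK^T and PV matmuls: 2 * (2 * seq_len^2 * head_dim) FLOPs per head.
    flops = 4 * batch * heads * seq_len ** 2 * head_dim
    if causal:
        flops /= 2  # only the lower-triangular tiles are computed
    return flops / seconds / 1e12
```

Note that skip-softmax numbers computed this way credit the kernel for work it skipped, which is the usual convention for reporting "effective" TFLOPS of sparse attention.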
### Before your PR is "*Ready for review*"
Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)
and your commits are signed (`git commit -s -S`).
Make sure you read and follow the [Security Best
Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors)
(e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(...,
weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ✅ / ❌ / N/A <!--- If ❌, explain
why. -->
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A
<!--- Mandatory -->
- Did you write any new necessary tests?: ✅ / ❌ / N/A <!--- Mandatory
for new features or examples. -->
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
✅ / ❌ / N/A <!--- Only for new features, API changes, critical bug fixes
or backward incompatible changes. -->
### Additional Information
<!-- E.g. related issue. -->
## Summary by CodeRabbit
* **New Features**
* Added a Triton "skip-softmax" tile-skipping option for flash attention
with a new attention keyword and configurable threshold (default 0.1).
* Added a new sparse attention method and a default sparse configuration
that enables the Triton skip-softmax method.
* **Tests**
* Added GPU tests covering threshold behavior, numerical fidelity vs
dense, shape preservation, decode-mode, and integration with sparsify.
* **Documentation**
* Updated changelog for the new feature and removed two prior listed
entries.
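The threshold-behavior tests mentioned above can be approximated offline: given a full pre-softmax score matrix, count how many (query row, key tile) pairs fall under the skip criterion. A hedged NumPy sketch (the criterion `exp(tile_max - row_max) < threshold` mirrors the PR's description of tile contribution, not the exact Triton code; `tile_skip_fraction` is a hypothetical helper):

```python
import numpy as np

def tile_skip_fraction(scores, tile=64, threshold=0.1):
    """Fraction of (query-row, key-tile) pairs a skip-softmax pass could drop.

    A pair is counted as skippable when the tile's softmax mass for that
    row is negligible: exp(tile_max - row_max) < threshold.
    `scores` is the full (n_q, n_k) pre-softmax score matrix.
    """
    row_max = scores.max(axis=1, keepdims=True)
    n_q, n_k = scores.shape
    skipped = total = 0
    for start in range(0, n_k, tile):
        tile_max = scores[:, start:start + tile].max(axis=1, keepdims=True)
        skipped += int((np.exp(tile_max - row_max) < threshold).sum())
        total += n_q
    return skipped / total
```

Sweeping `threshold` over a real score matrix with this helper gives a quick estimate of how much work the kernel can skip before running the GPU benchmark.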
---------
Signed-off-by: Kai Xu <kaix@nvidia.com>

1 parent: b1f9f01
File tree (10 files changed: +602, −110 lines):
- modelopt/torch/kernels
- modelopt/torch/sparsity/attention_sparsity
- modelopt/torch/sparsity/attention_sparsity/methods
- tests/gpu/torch/sparsity/attention_sparsity

(Per-file diff contents not captured.)