Commit 012fb20

Add unit test for skip softmax
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 9c0de6c commit 012fb20

2 files changed

Lines changed: 46 additions & 1 deletion

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ NVIDIA Model Optimizer Changelog
 **New Features**
 
 - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics.
-- Add skip-softmax tile skipping to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). KV tiles with negligible attention scores are skipped entirely during prefill, saving V loads and computation on long sequences with strong attention locality. Integrates with ``mtsa.sparsify()`` via the ``triton_skip_softmax`` method.
+- Add skip-softmax tile skipping to the Triton flash attention kernel (``modelopt.torch.kernels.triton_fa``). Integrates with the ``mtsa.sparsify()`` API.
 
 **Bug Fixes**
 

tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py

Lines changed: 45 additions & 0 deletions
@@ -553,3 +553,48 @@ def test_decode_single_token(self):
             skip_softmax_threshold=1e-3,
         )
         torch.testing.assert_close(out_skip, out_dense, rtol=5e-2, atol=5e-2)
+
+
+@pytest.mark.skipif(not TRITON_KERNEL_AVAILABLE, reason="Need CUDA + triton")
+class TestSkipSoftmaxHFIntegration:
+    """HF integration for skip-softmax via mtsa.sparsify()."""
+
+    def test_skip_softmax_via_sparsify(self, tiny_llama_dir):
+        """mtsa.sparsify() with triton_skip_softmax produces finite logits."""
+        pytest.importorskip("transformers")
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        import modelopt.torch.sparsity.attention_sparsity as mtsa
+
+        tok = AutoTokenizer.from_pretrained(tiny_llama_dir)
+        if tok.pad_token_id is None:
+            tok.pad_token_id = tok.eos_token_id
+        ids = torch.randint(1, tok.vocab_size, (1, 64), device="cuda")
+
+        # Dense baseline (triton backend, no skip)
+        model_dense = AutoModelForCausalLM.from_pretrained(
+            tiny_llama_dir,
+            attn_implementation="modelopt_triton",
+            torch_dtype=torch.bfloat16,
+            device_map="cuda",
+        )
+        model_dense.eval()
+        with torch.no_grad():
+            logits_dense = model_dense(input_ids=ids).logits
+        del model_dense
+
+        # Skip-softmax via mtsa.sparsify()
+        model_skip = AutoModelForCausalLM.from_pretrained(
+            tiny_llama_dir,
+            torch_dtype=torch.bfloat16,
+            device_map="cuda",
+        )
+        mtsa.sparsify(model_skip, mtsa.SKIP_SOFTMAX_TRITON_DEFAULT)
+        model_skip.eval()
+        with torch.no_grad():
+            logits_skip = model_skip(input_ids=ids).logits
+
+        assert not torch.isnan(logits_skip).any(), "NaN in skip-softmax logits"
+        assert not torch.isinf(logits_skip).any(), "Inf in skip-softmax logits"
+        # On short sequences (64 tokens), no tiles are skipped — output should match dense
+        torch.testing.assert_close(logits_skip, logits_dense, rtol=1e-3, atol=1e-3)
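
The final assertion rests on the idea that when no KV tile is skipped, the skip-softmax path should reproduce the dense result. The pure-Python sketch below illustrates that idea for a single query row. It is not the ModelOpt Triton kernel: the function names, the tile size, and the skip criterion (a tile is dropped when its largest score falls more than `log(threshold)` below the running maximum) are all assumptions made for illustration.

```python
import math


def dense_attention(scores, values):
    """Softmax-weighted sum of scalar values for one query row."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def skip_softmax_attention(scores, values, tile=4, threshold=1e-3):
    """Online softmax over KV tiles; returns (output, number_of_skipped_tiles).

    Hypothetical skip rule (an assumption, not the kernel's documented rule):
    skip a tile if exp(tile_max - running_max) < threshold.
    """
    log_thr = math.log(threshold)
    run_max, denom, acc, skipped = -math.inf, 0.0, 0.0, 0
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        tile_max = max(s_tile)
        if tile_max - run_max < log_thr:
            # Every weight in this tile is negligible: skip its V values.
            skipped += 1
            continue
        new_max = max(run_max, tile_max)
        scale = math.exp(run_max - new_max)  # rescale earlier partial sums
        weights = [math.exp(s - new_max) for s in s_tile]
        denom = denom * scale + sum(weights)
        acc = acc * scale + sum(w * v for w, v in zip(weights, v_tile))
        run_max = new_max
    return acc / denom, skipped


# Scores with a small spread: no tile is skipped, so the result matches dense.
flat_scores = [0.1, 0.2, 0.0, 0.3, 0.25, 0.05, 0.15, 0.1]
flat_values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
out_flat, skipped_flat = skip_softmax_attention(flat_scores, flat_values)

# Strongly peaked scores: the second tile is skipped, with only a small error.
peaked_scores = [10.0] * 4 + [-10.0] * 4
peaked_values = [1.0] * 4 + [5.0] * 4
out_peaked, skipped_peaked = skip_softmax_attention(peaked_scores, peaked_values)
```

Under these assumptions, the first case corresponds to the tight tolerance used in the test, while the second shows that skipping changes the output only by an amount bounded by the discarded weights.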
