Skip softmax calibration via Triton kernel #1597

Open
rohansjoshi wants to merge 1 commit into main from rohjoshi/triton-ss-calib
@rohansjoshi
Contributor

@rohansjoshi rohansjoshi commented Jun 2, 2026

What does this PR do?

Adds skip-softmax calibration for LLMs via a Triton kernel, reusing the existing kernel already used for diffusion models.

Type of change: New feature

Usage

python hf_sa.py --pyt_ckpt_path Qwen/Qwen3-8B --sparse_attn skip_softmax_triton_calib

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added new sparse-attention configuration option enabling Triton-based skip-softmax calibration.
    • Added CLI argument --calib_data_dir to customize calibration data directory paths.
    • Enhanced calibration handling with automatic default data directory configuration for RULER methodology.
  • Tests

    • Added comprehensive GPU-based tests for Triton calibration validation and inference accuracy verification.
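
As an illustration of the `--calib_data_dir` behavior summarized above, here is a minimal sketch of a CLI flag that falls back to a local `data` folder when omitted. The helper names (`build_parser`, `resolve_calib_data_dir`) and the example paths are hypothetical and are not taken from `hf_sa.py`; only the flag names `--sparse_attn` and `--calib_data_dir` and the value `skip_softmax_triton_calib` come from this PR.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser with only the two flags named in this PR."""
    parser = argparse.ArgumentParser(description="Sketch of calibration CLI flags")
    parser.add_argument("--sparse_attn", default="skip_softmax_triton_calib")
    # None means "not provided"; the default is resolved later against the script dir.
    parser.add_argument("--calib_data_dir", type=Path, default=None)
    return parser


def resolve_calib_data_dir(args: argparse.Namespace, script_dir: Path) -> Path:
    """Use the user override if given, otherwise a 'data' folder next to the script."""
    if args.calib_data_dir is not None:
        return args.calib_data_dir
    return script_dir / "data"


# Assumed script location, for illustration only.
script_dir = Path("/tmp/example")

default_dir = resolve_calib_data_dir(build_parser().parse_args([]), script_dir)
custom_dir = resolve_calib_data_dir(
    build_parser().parse_args(["--calib_data_dir", "/mnt/ruler"]), script_dir
)
print(default_dir, custom_dir)
```

Resolving the default after parsing keeps "flag omitted" distinguishable from "flag set to the default path", which may matter if only some configurations need calibration data.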

Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
@rohansjoshi rohansjoshi requested review from a team as code owners June 2, 2026 00:53
@coderabbitai
Contributor

coderabbitai Bot commented Jun 2, 2026

Walkthrough

This PR adds HuggingFace Triton skip-softmax calibration support by introducing thread-local calibration state management in the HF attention backend, defining a prefill-focused configuration, integrating HF Triton into the existing calibration orchestration, extending the CLI example with data directory handling, and validating the implementation with GPU tests.

Changes

HF Triton Skip-Softmax Calibration Support

| Layer | File(s) | Summary |
| --- | --- | --- |
| HF Triton backend calibration infrastructure | modelopt/torch/kernels/common/attention/hf_triton_attention.py | Introduces thread-local state for calibration config, threshold trials, and counters. Routes prefill calls through attention_calibrate during calibration mode, accumulates counters across forward passes, and exports accessor functions get_calibration_counters and get_calibration_seq_k. |
| Prefill calibration configuration constant | modelopt/torch/sparsity/attention_sparsity/config.py | Defines SKIP_SOFTMAX_TRITON_CALIB dict with prefill-only targeting, method="triton_skip_softmax", backend="triton", and full-sequence chunking (chunk_size=-1). |
| TritonSkipSoftmaxMethod HF Triton wiring | modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py | Extends _set_triton_backends, _clear_triton_backends, and _collect_calibration_stats to configure, clear, and read HF Triton backend state alongside existing diffusers and LTX backends. |
| Example script calibration CLI and data defaults | examples/llm_sparsity/attention_sparsity/hf_sa.py | Registers skip_softmax_triton_calib in configuration choices, defaults calibration data directory to local data folder, and adds --calib_data_dir CLI argument for user override. |
| GPU tests for calibration flow and counter mechanics | tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py | Validates end-to-end calibration on tiny Llama model (parameter bounds, inference stability, no NaN outputs) and tests counter monotonicity across thresholds in HF Triton backend. |
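
To make the thread-local calibration state described in the walkthrough more concrete, here is a minimal sketch of the general pattern: per-thread config, counters accumulated across forward passes, and an accessor. All function names below are hypothetical stand-ins, not the actual API of `hf_triton_attention.py`, and the per-threshold counts are made-up inputs rather than real kernel output.

```python
import threading
from typing import List, Optional

# Each thread gets its own attribute namespace on this object.
_state = threading.local()


def set_calibration(threshold_trials: List[float]) -> None:
    """Enable calibration mode for the current thread (hypothetical helper)."""
    _state.trials = list(threshold_trials)
    _state.counters = None


def clear_calibration() -> None:
    """Disable calibration mode and drop accumulated counters."""
    _state.trials = None
    _state.counters = None


def record_forward(counts_per_threshold: List[int]) -> None:
    """Accumulate one forward pass worth of per-threshold counts."""
    if getattr(_state, "trials", None) is None:
        return  # Not calibrating on this thread: nothing to record.
    if _state.counters is None:
        _state.counters = [0] * len(_state.trials)
    for i, n in enumerate(counts_per_threshold):
        _state.counters[i] += n


def get_counters() -> Optional[List[int]]:
    """Accessor for the counters accumulated on the current thread."""
    return getattr(_state, "counters", None)


set_calibration([0.1, 0.5])
record_forward([3, 7])  # first forward pass
record_forward([1, 2])  # second forward pass
main_counters = get_counters()

# A different thread never enabled calibration, so it sees no counters.
seen_in_other_thread = []
worker = threading.Thread(target=lambda: seen_in_other_thread.append(get_counters()))
worker.start()
worker.join()
print(main_counters, seen_in_other_thread)
```

Keeping the state in `threading.local()` means one thread's calibration run cannot leak counters into attention calls made on another thread, at the cost of having to set and clear the state on the thread that runs the forward passes.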

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and concisely summarizes the main change: adding skip softmax calibration via Triton kernel, which is the primary feature introduced across all modified files. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 88.89% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | No security anti-patterns found: no unsafe torch.load, numpy.load, hardcoded trust_remote_code=True, eval/exec, # nosec, or new non-permissive dependencies per SECURITY.md. |


@github-actions
Contributor

github-actions Bot commented Jun 2, 2026

PR Preview Action v1.8.1

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1597/

Built to branch gh-pages at 2026-06-02 00:58 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai Bot left a comment

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/llm_sparsity/attention_sparsity/hf_sa.py (1)

176-182: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix --target_sparse_ratio override to avoid enabling decode calibration when Triton skip-softmax is prefill-only.
calibration/calibrate.py gates decode calibration on target_sparse_ratio["decode"] > 0 (via calibrate_decode = target_dict.get("decode", 0.0) > 0.0), so the current override in examples/llm_sparsity/attention_sparsity/hf_sa.py will force decode to be enabled even when the selected Triton calibration config is meant to be prefill-only (SKIP_SOFTMAX_TRITON_CALIB). Preserve the existing target_sparse_ratio phase set instead of unconditionally writing both phases.

🔧 Proposed fix to only override phases already present
     if args.target_sparse_ratio is not None:
         calib = sparse_cfg.setdefault("calibration", {})
         assert isinstance(calib, dict)
-        calib["target_sparse_ratio"] = {
-            "prefill": args.target_sparse_ratio,
-            "decode": args.target_sparse_ratio,
-        }
+        existing = calib.get("target_sparse_ratio", {"prefill": 0.5, "decode": 0.5})
+        calib["target_sparse_ratio"] = {
+            phase: args.target_sparse_ratio for phase in existing
+        }
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_sparsity/attention_sparsity/hf_sa.py` around lines 176 - 182,
The current override always writes both "prefill" and "decode" into
calib["target_sparse_ratio"], which forces decode calibration on even for
prefill-only configs; instead, fetch the existing target dict (e.g., target =
calib.get("target_sparse_ratio", {})) and only update the phases that are
already present in that dict (or at least only set "prefill" if "decode" is not
present) when applying args.target_sparse_ratio; in practice modify the block
that uses sparse_cfg.setdefault("calibration", {}) so it merges
args.target_sparse_ratio into the existing target dict key-by-key (updating only
existing phase keys like "prefill" and "decode") rather than unconditionally
writing both phases.
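
The proposed fix above can be exercised in isolation. The following is a minimal sketch, assuming a prefill-only config whose `target_sparse_ratio` contains just a `"prefill"` key (the exact contents of SKIP_SOFTMAX_TRITON_CALIB are not shown in this review, so that shape is an assumption). The function name `override_target_sparse_ratio` is hypothetical; the gating expression mirrors the one quoted from calibration/calibrate.py.

```python
def override_target_sparse_ratio(sparse_cfg: dict, ratio: float) -> dict:
    """Apply a CLI ratio only to the phases the config already declares."""
    calib = sparse_cfg.setdefault("calibration", {})
    # Fall back to both phases only when the config declares none at all.
    existing = calib.get("target_sparse_ratio", {"prefill": 0.5, "decode": 0.5})
    calib["target_sparse_ratio"] = {phase: ratio for phase in existing}
    return sparse_cfg


# Assumed shape of a prefill-only config, for illustration.
prefill_only_cfg = {"calibration": {"target_sparse_ratio": {"prefill": 0.5}}}
override_target_sparse_ratio(prefill_only_cfg, 0.7)
target = prefill_only_cfg["calibration"]["target_sparse_ratio"]

# Same gating expression as quoted in the review comment.
calibrate_decode = target.get("decode", 0.0) > 0.0

# A config with no target_sparse_ratio still gets both phases.
empty_cfg = override_target_sparse_ratio({}, 0.3)
print(target, calibrate_decode, empty_cfg)
```

With the original unconditional override, `target` would contain a `"decode"` entry and `calibrate_decode` would be true even for the prefill-only config.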
🧹 Nitpick comments (3)
modelopt/torch/kernels/common/attention/hf_triton_attention.py (1)

176-176: 💤 Low value

Add a brief comment explaining the deferred import.

Per CONTRIBUTING.md, function-level imports should include a brief comment naming the reason (e.g., lazy/optional/circular). This import is deferred to the calibration path—add a comment such as # Lazy import: only needed during calibration.

📝 Suggested comment
-        from modelopt.torch.kernels.common.attention import attention_calibrate
+        # Lazy import: only needed during calibration.
+        from modelopt.torch.kernels.common.attention import attention_calibrate
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/kernels/common/attention/hf_triton_attention.py` at line 176,
Add a short explanatory comment for the deferred import of attention_calibrate
to indicate it's performed lazily for calibration only; locate the import
statement "from modelopt.torch.kernels.common.attention import
attention_calibrate" in hf_triton_attention.py and add a brief comment such as
"# Lazy import: only needed during calibration" immediately above or inline with
that import.
tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py (2)

114-114: 💤 Low value

Prefer math.isfinite(a) over wrapping a Python scalar in a tensor.

a is already a Python float, so torch.isfinite(torch.tensor(a)) allocates a tensor unnecessarily. math.isfinite(a) is clearer.

♻️ Suggested change
-            assert a > 0 and torch.isfinite(torch.tensor(a))
+            assert a > 0 and math.isfinite(a)

(add import math at the top of the file)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`
at line 114, Replace the tensor-based finiteness check with Python's
math.isfinite: add "import math" at the top of the test file and change the
assertion that currently reads "assert a > 0 and
torch.isfinite(torch.tensor(a))" to "assert a > 0 and math.isfinite(a)"; refer
to the variable "a" in the failing assertion to locate the line to update.
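
For reference, a small sketch of the scalar check suggested above, using only the standard library. The wrapper name `is_valid_fit_param` is hypothetical; the expression inside it is the one from the suggested change.

```python
import math


def is_valid_fit_param(a: float) -> bool:
    """True when a is a positive, finite Python float (no tensor allocation)."""
    return a > 0 and math.isfinite(a)


checks = {
    "positive": is_valid_fit_param(1.5),
    "infinite": is_valid_fit_param(float("inf")),
    "nan": is_valid_fit_param(float("nan")),
    "negative": is_valid_fit_param(-1.0),
}
print(checks)
```

Note that `float("nan") > 0` is already false, so the `math.isfinite` call mainly guards against positive infinity.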

24-38: 💤 Low value

Move the IS_AVAILABLE import up with the other top-level imports.

Line 38 is a module-level import placed below pytestmark. It is used in the skipif decorators, so grouping it with lines 24-31 keeps import ordering conventional and avoids surprise. Minor placement nit.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`
around lines 24 - 38, Move the module-level import "from
modelopt.torch.kernels.common.attention import IS_AVAILABLE as
TRITON_KERNEL_AVAILABLE" up into the top import block with the other imports
(near AutoModelForCausalLM and SparseAttentionModule) so it's defined before
pytestmark and available for the skipif decorators; update its position only (no
code changes) so skip conditions referencing TRITON_KERNEL_AVAILABLE remain
valid.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: dc3aa8c1-a287-45c1-8196-2702a9981726

📥 Commits

Reviewing files that changed from the base of the PR and between 905259f and 7802f9f.

📒 Files selected for processing (5)
  • examples/llm_sparsity/attention_sparsity/hf_sa.py
  • modelopt/torch/kernels/common/attention/hf_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py

Comment on lines +86 to +94
    def test_sparsify_triton_calib_sets_params(self, tiny_llama_dir):
        """Running SKIP_SOFTMAX_TRITON_CALIB fits a finite exponential model."""
        import copy

        model = _load_eager(tiny_llama_dir)

        # Use the calibrator's default (dense) threshold trials so the collected
        # sparsity densely covers the (10%, 90%) window the fit filters on.
        config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Move import copy to the top of the file.

import copy is a standard-library import placed inside the test method (also at line 121). Top-level imports surface import errors at collection time rather than mid-test.

💚 Suggested change
 import pytest
 import torch
+import copy
     def test_sparsify_triton_calib_sets_params(self, tiny_llama_dir):
         """Running SKIP_SOFTMAX_TRITON_CALIB fits a finite exponential model."""
-        import copy
-
         model = _load_eager(tiny_llama_dir)
     def test_calibrated_model_inference(self, tiny_llama_dir):
         """A model calibrated through the Triton path still runs inference cleanly."""
-        import copy
-
         model = _load_eager(tiny_llama_dir)
As per coding guidelines: "Imports inside functions or test methods without explicit justification... Imports belong at the top of the file so import errors surface at collection time."
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     def test_sparsify_triton_calib_sets_params(self, tiny_llama_dir):
         """Running SKIP_SOFTMAX_TRITON_CALIB fits a finite exponential model."""
-        import copy
         model = _load_eager(tiny_llama_dir)
         # Use the calibrator's default (dense) threshold trials so the collected
         # sparsity densely covers the (10%, 90%) window the fit filters on.
         config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`
around lines 86 - 94, Move the local "import copy" out of the test method and
place it with the other top-level imports at the top of the file so import
errors are raised at collection time; specifically remove the inline import
inside test_sparsify_triton_calib_sets_params and add a single top-level "import
copy" that will be used by that function (and any other tests) to deepcopy
SKIP_SOFTMAX_TRITON_CALIB.

Comment on lines +143 to +149
        from modelopt.torch.kernels.common.attention.hf_triton_attention import (
            clear_hf_triton_skip_softmax_config,
            get_calibration_counters,
            get_calibration_seq_k,
            set_hf_triton_skip_softmax_config,
            triton_attention_forward,
        )

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Lift the hf_triton_attention import to module scope, or justify the deferral.

This in-method import has no comment explaining why it is deferred. If it is to avoid a hard triton import at collection time, that is an acceptable optional-dependency deferral but it must carry a brief comment naming the reason; otherwise move it to the top of the file with the other imports.

As per coding guidelines: "The only acceptable in-function imports are for circular imports or optional dependencies... and those should carry a brief comment naming the reason."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`
around lines 143 - 149, The in-function import of
modelopt.torch.kernels.common.attention.hf_triton_attention (and its symbols
clear_hf_triton_skip_softmax_config, get_calibration_counters,
get_calibration_seq_k, set_hf_triton_skip_softmax_config,
triton_attention_forward) must be either moved to module level with the other
imports or explicitly justified. If the deferral exists to avoid a hard
dependency on Triton at test-collection time, keep the import inside the
function and add a concise comment like "deferred import to avoid hard Triton
dependency during test collection" immediately above it so the reason for the
optional-dependency deferral is clear and follows coding guidelines; otherwise,
move the import to the top of the file.
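
A minimal sketch of the optional-dependency deferral pattern discussed in this comment, using only the standard library. The helper names are hypothetical, and the demo uses `json` and a deliberately missing module name rather than Triton so that it runs anywhere; in a pytest file the availability flag would typically feed a skip condition instead.

```python
import importlib
import importlib.util


def optional_module_available(name: str) -> bool:
    """Check for a top-level module without importing it."""
    return importlib.util.find_spec(name) is not None


def describe_optional_module(name: str) -> str:
    """Import the module only when it is present, and say why it is deferred."""
    if not optional_module_available(name):
        return f"skipped: {name} not installed"
    # Deferred import: optional dependency, avoids a hard import at collection time.
    module = importlib.import_module(name)
    return f"{module.__name__} available"


present = describe_optional_module("json")
missing = describe_optional_module("surely_missing_module_xyz")
print(present)
print(missing)
```

The comment directly above the deferred import is the part the coding guideline quoted above asks for: it names the reason so a later reader does not hoist the import back to module scope.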

@codecov

codecov Bot commented Jun 2, 2026

Codecov Report

❌ Patch coverage is 73.68421% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 68.13%. Comparing base (905259f) to head (7802f9f).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...ch/kernels/common/attention/hf_triton_attention.py | 61.53% | 10 Missing ⚠️ |

❗ There is a different number of reports uploaded between BASE (905259f) and HEAD (7802f9f).

HEAD has 2 uploads less than BASE

| Flag | BASE (905259f) | HEAD (7802f9f) |
| --- | --- | --- |
| gpu | 4 | 3 |
| examples | 12 | 11 |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1597      +/-   ##
==========================================
- Coverage   77.38%   68.13%   -9.25%     
==========================================
  Files         479      479              
  Lines       52435    52473      +38     
==========================================
- Hits        40578    35754    -4824     
- Misses      11857    16719    +4862     
| Flag | Coverage Δ |
| --- | --- |
| examples | 41.68% <55.26%> (+0.86%) ⬆️ |
| gpu | 27.94% <2.63%> (-32.50%) ⬇️ |
| regression | 15.19% <2.63%> (+0.06%) ⬆️ |
| unit | 53.63% <36.84%> (+<0.01%) ⬆️ |
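
As a sanity check on the report above, the headline percentages can be reproduced from the hit and line counts in the coverage diff. This is a sketch under two assumptions that are inferred from the numbers rather than stated by Codecov: percentages are truncated (not rounded) to two decimals, and the patch covers 38 tracked lines (the +38 line delta), of which 10 are missing.

```python
def coverage_bp(hits: int, lines: int) -> int:
    """Coverage in basis points (hundredths of a percent), truncated."""
    return hits * 10000 // lines


def fmt(bp: int) -> str:
    """Render basis points as a percentage with two decimals."""
    return f"{bp // 100}.{bp % 100:02d}%"


base_bp = coverage_bp(40578, 52435)  # BASE: hits / lines
head_bp = coverage_bp(35754, 52473)  # HEAD: hits / lines
drop_bp = base_bp - head_bp

# Assumed patch size: 38 lines with 10 missing, so 28 covered.
patch_bp = coverage_bp(38 - 10, 38)

print(fmt(base_bp), fmt(head_bp), fmt(drop_bp), fmt(patch_bp))
# prints: 77.38% 68.13% 9.25% 73.68%
```

The 9.25-point project drop comes almost entirely from roughly 4,800 fewer hits at HEAD, which is consistent with the report's note that HEAD has fewer gpu and examples uploads than BASE, rather than from the 38 lines this PR adds.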
