Skip softmax calibration via Triton kernel#1597
Conversation
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
📝 WalkthroughWalkthroughThis PR adds HuggingFace Triton skip-softmax calibration support by introducing thread-local calibration state management in the HF attention backend, defining a prefill-focused configuration, integrating HF Triton into the existing calibration orchestration, extending the CLI example with data directory handling, and validating the implementation with GPU tests. ChangesHF Triton Skip-Softmax Calibration Support
🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 6✅ Passed checks (6 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
|
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/llm_sparsity/attention_sparsity/hf_sa.py (1)
176-182:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winFix
--target_sparse_ratiooverride to avoid enabling decode calibration when Triton skip-softmax is prefill-only.
calibration/calibrate.pygates decode calibration ontarget_sparse_ratio["decode"] > 0(viacalibrate_decode = target_dict.get("decode", 0.0) > 0.0), so the current override inexamples/llm_sparsity/attention_sparsity/hf_sa.pywill forcedecodeto be enabled even when the selected Triton calibration config is meant to be prefill-only (SKIP_SOFTMAX_TRITON_CALIB). Preserve the existingtarget_sparse_ratiophase set instead of unconditionally writing both phases.🔧 Proposed fix to only override phases already present
if args.target_sparse_ratio is not None: calib = sparse_cfg.setdefault("calibration", {}) assert isinstance(calib, dict) - calib["target_sparse_ratio"] = { - "prefill": args.target_sparse_ratio, - "decode": args.target_sparse_ratio, - } + existing = calib.get("target_sparse_ratio", {"prefill": 0.5, "decode": 0.5}) + calib["target_sparse_ratio"] = { + phase: args.target_sparse_ratio for phase in existing + }🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/llm_sparsity/attention_sparsity/hf_sa.py` around lines 176 - 182, The current override always writes both "prefill" and "decode" into calib["target_sparse_ratio"], which forces decode calibration on even for prefill-only configs; instead, fetch the existing target dict (e.g., target = calib.get("target_sparse_ratio", {})) and only update the phases that are already present in that dict (or at least only set "prefill" if "decode" is not present) when applying args.target_sparse_ratio; in practice modify the block that uses sparse_cfg.setdefault("calibration", {}) so it merges args.target_sparse_ratio into the existing target dict key-by-key (updating only existing phase keys like "prefill" and "decode") rather than unconditionally writing both phases.
🧹 Nitpick comments (3)
modelopt/torch/kernels/common/attention/hf_triton_attention.py (1)
176-176: 💤 Low valueAdd a brief comment explaining the deferred import.
Per CONTRIBUTING.md, function-level imports should include a brief comment naming the reason (e.g., lazy/optional/circular). This import is deferred to the calibration path—add a comment such as
# Lazy import: only needed during calibration.📝 Suggested comment
- from modelopt.torch.kernels.common.attention import attention_calibrate + # Lazy import: only needed during calibration. + from modelopt.torch.kernels.common.attention import attention_calibrate🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/kernels/common/attention/hf_triton_attention.py` at line 176, Add a short explanatory comment for the deferred import of attention_calibrate to indicate it's performed lazily for calibration only; locate the import statement "from modelopt.torch.kernels.common.attention import attention_calibrate" in hf_triton_attention.py and add a brief comment such as "# Lazy import: only needed during calibration" immediately above or inline with that import.tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py (2)
114-114: 💤 Low valuePrefer
math.isfinite(a)over wrapping a Python scalar in a tensor.
ais already a Python float, sotorch.isfinite(torch.tensor(a))allocates a tensor unnecessarily.math.isfinite(a)is clearer.♻️ Suggested change
- assert a > 0 and torch.isfinite(torch.tensor(a)) + assert a > 0 and math.isfinite(a)(add
import mathat the top of the file)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py` at line 114, Replace the tensor-based finiteness check with Python's math.isfinite: add "import math" at the top of the test file and change the assertion that currently reads "assert a > 0 and torch.isfinite(torch.tensor(a))" to "assert a > 0 and math.isfinite(a)"; refer to the variable "a" in the failing assertion to locate the line to update.
24-38: 💤 Low valueMove the
IS_AVAILABLEimport up with the other top-level imports.Line 38 is a module-level import placed below
pytestmark. It is used in theskipifdecorators, so grouping it with lines 24-31 keeps import ordering conventional and avoids surprise. Minor placement nit.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py` around lines 24 - 38, Move the module-level import "from modelopt.torch.kernels.common.attention import IS_AVAILABLE as TRITON_KERNEL_AVAILABLE" up into the top import block with the other imports (near AutoModelForCausalLM and SparseAttentionModule) so it's defined before pytestmark and available for the skipif decorators; update its position only (no code changes) so skip conditions referencing TRITON_KERNEL_AVAILABLE remain valid.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`:
- Around line 143-149: The in-function import of
modelopt.torch.kernels.common.attention.hf_triton_attention (and its symbols
clear_hf_triton_skip_softmax_config, get_calibration_counters,
get_calibration_seq_k, set_hf_triton_skip_softmax_config,
triton_attention_forward) must be either moved to module-level with the other
imports or the deferral must be explicitly justified; if this was done to avoid
a hard dependency on Triton at test-collection time, move the import to the top
of the file and/or add a concise comment like "deferred import to avoid hard
Triton dependency during test collection" immediately above the in-function
import so the reason for the optional-dependency deferral is clear and follows
coding guidelines.
- Around line 86-94: Move the local "import copy" out of the test method and
place it with the other top-level imports at the top of the file so import
errors are raised at collection time; specifically remove the inline import
inside test_sparsify_triton_calib_sets_params and add a single top-level "import
copy" that will be used by that function (and any other tests) to deepcopy
SKIP_SOFTMAX_TRITON_CALIB.
---
Outside diff comments:
In `@examples/llm_sparsity/attention_sparsity/hf_sa.py`:
- Around line 176-182: The current override always writes both "prefill" and
"decode" into calib["target_sparse_ratio"], which forces decode calibration on
even for prefill-only configs; instead, fetch the existing target dict (e.g.,
target = calib.get("target_sparse_ratio", {})) and only update the phases that
are already present in that dict (or at least only set "prefill" if "decode" is
not present) when applying args.target_sparse_ratio; in practice modify the
block that uses sparse_cfg.setdefault("calibration", {}) so it merges
args.target_sparse_ratio into the existing target dict key-by-key (updating only
existing phase keys like "prefill" and "decode") rather than unconditionally
writing both phases.
---
Nitpick comments:
In `@modelopt/torch/kernels/common/attention/hf_triton_attention.py`:
- Line 176: Add a short explanatory comment for the deferred import of
attention_calibrate to indicate it's performed lazily for calibration only;
locate the import statement "from modelopt.torch.kernels.common.attention import
attention_calibrate" in hf_triton_attention.py and add a brief comment such as
"# Lazy import: only needed during calibration" immediately above or inline with
that import.
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`:
- Line 114: Replace the tensor-based finiteness check with Python's
math.isfinite: add "import math" at the top of the test file and change the
assertion that currently reads "assert a > 0 and
torch.isfinite(torch.tensor(a))" to "assert a > 0 and math.isfinite(a)"; refer
to the variable "a" in the failing assertion to locate the line to update.
- Around line 24-38: Move the module-level import "from
modelopt.torch.kernels.common.attention import IS_AVAILABLE as
TRITON_KERNEL_AVAILABLE" up into the top import block with the other imports
(near AutoModelForCausalLM and SparseAttentionModule) so it's defined before
pytestmark and available for the skipif decorators; update its position only (no
code changes) so skip conditions referencing TRITON_KERNEL_AVAILABLE remain
valid.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: dc3aa8c1-a287-45c1-8196-2702a9981726
📒 Files selected for processing (5)
examples/llm_sparsity/attention_sparsity/hf_sa.pymodelopt/torch/kernels/common/attention/hf_triton_attention.pymodelopt/torch/sparsity/attention_sparsity/config.pymodelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.pytests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py
| def test_sparsify_triton_calib_sets_params(self, tiny_llama_dir): | ||
| """Running SKIP_SOFTMAX_TRITON_CALIB fits a finite exponential model.""" | ||
| import copy | ||
|
|
||
| model = _load_eager(tiny_llama_dir) | ||
|
|
||
| # Use the calibrator's default (dense) threshold trials so the collected | ||
| # sparsity densely covers the (10%, 90%) window the fit filters on. | ||
| config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB) |
There was a problem hiding this comment.
Move import copy to the top of the file.
import copy is a standard-library import placed inside the test method (also at line 121). Top-level imports surface import errors at collection time rather than mid-test.
💚 Suggested change
import pytest
import torch
+import copy def test_sparsify_triton_calib_sets_params(self, tiny_llama_dir):
"""Running SKIP_SOFTMAX_TRITON_CALIB fits a finite exponential model."""
- import copy
-
model = _load_eager(tiny_llama_dir) def test_calibrated_model_inference(self, tiny_llama_dir):
"""A model calibrated through the Triton path still runs inference cleanly."""
- import copy
-
model = _load_eager(tiny_llama_dir)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def test_sparsify_triton_calib_sets_params(self, tiny_llama_dir): | |
| """Running SKIP_SOFTMAX_TRITON_CALIB fits a finite exponential model.""" | |
| import copy | |
| model = _load_eager(tiny_llama_dir) | |
| # Use the calibrator's default (dense) threshold trials so the collected | |
| # sparsity densely covers the (10%, 90%) window the fit filters on. | |
| config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB) | |
| def test_sparsify_triton_calib_sets_params(self, tiny_llama_dir): | |
| """Running SKIP_SOFTMAX_TRITON_CALIB fits a finite exponential model.""" | |
| model = _load_eager(tiny_llama_dir) | |
| # Use the calibrator's default (dense) threshold trials so the collected | |
| # sparsity densely covers the (10%, 90%) window the fit filters on. | |
| config = copy.deepcopy(SKIP_SOFTMAX_TRITON_CALIB) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`
around lines 86 - 94, Move the local "import copy" out of the test method and
place it with the other top-level imports at the top of the file so import
errors are raised at collection time; specifically remove the inline import
inside test_sparsify_triton_calib_sets_params and add a single top-level "import
copy" that will be used by that function (and any other tests) to deepcopy
SKIP_SOFTMAX_TRITON_CALIB.
| from modelopt.torch.kernels.common.attention.hf_triton_attention import ( | ||
| clear_hf_triton_skip_softmax_config, | ||
| get_calibration_counters, | ||
| get_calibration_seq_k, | ||
| set_hf_triton_skip_softmax_config, | ||
| triton_attention_forward, | ||
| ) |
There was a problem hiding this comment.
Lift the hf_triton_attention import to module scope, or justify the deferral.
This in-method import has no comment explaining why it is deferred. If it is to avoid a hard triton import at collection time, that is an acceptable optional-dependency deferral but it must carry a brief comment naming the reason; otherwise move it to the top of the file with the other imports.
As per coding guidelines: "The only acceptable in-function imports are for circular imports or optional dependencies... and those should carry a brief comment naming the reason."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_calibration_gpu.py`
around lines 143 - 149, The in-function import of
modelopt.torch.kernels.common.attention.hf_triton_attention (and its symbols
clear_hf_triton_skip_softmax_config, get_calibration_counters,
get_calibration_seq_k, set_hf_triton_skip_softmax_config,
triton_attention_forward) must be either moved to module-level with the other
imports or the deferral must be explicitly justified; if this was done to avoid
a hard dependency on Triton at test-collection time, move the import to the top
of the file and/or add a concise comment like "deferred import to avoid hard
Triton dependency during test collection" immediately above the in-function
import so the reason for the optional-dependency deferral is clear and follows
coding guidelines.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1597 +/- ##
==========================================
- Coverage 77.38% 68.13% -9.25%
==========================================
Files 479 479
Lines 52435 52473 +38
==========================================
- Hits 40578 35754 -4824
- Misses 11857 16719 +4862
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
What does this PR do?
Adds skip softmax calibration for LLMs via Triton kernel (leveraging existing kernel used for diffusion)
Type of change: New feature
Usage
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: ✅ / ❌ / N/AAdditional Information
Summary by CodeRabbit
New Features
--calib_data_dirto customize calibration data directory paths.Tests