Skip to content

Commit 8b01ba4

Browse files
rohansjoshikaix-nv
andauthored
Skip softmax calibration via Triton kernel (#1597)
### What does this PR do? Adds skip softmax calibration for LLMs via Triton kernel (leveraging existing kernel used for diffusion) Type of change: New feature <!-- Details about the change. --> ### Usage ``` python hf_sa.py --pyt_ckpt_path Qwen/Qwen3-8B --sparse_attn skip_softmax_triton_calib ``` The Triton calibration equals PyTorch at every threshold, for both phases: | threshold | prefill triton/pytorch | decode triton/pytorch | |------|------------------------|-----------------------| | 0.30 | 0.0% / 0.0% | 12.5% / 12.5% | | 0.50 | 0.0% / 0.0% | 37.5% / 37.5% | | 0.70 | 10.0% / 10.0% | 62.5% / 62.5% | ### Before your PR is "*Ready for review*" Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.). - Is this change backward compatible?: ✅ / ❌ / N/A <!--- If ❌, explain why. --> - If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A <!--- Mandatory --> - Did you write any new necessary tests?: ✅ / ❌ / N/A <!--- Mandatory for new features or examples. --> - Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A <!--- Only for new features, API changes, critical bug fixes or backward incompatible changes. --> - Did you get Claude approval on this PR?: ✅ / ❌ / N/A <!--- Run `/claude review`. NVIDIA org members can self-trigger for complex changes; orthogonal to CodeRabbit. --> ### Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a Triton-based skip-softmax sparse-attention calibration option and a CLI flag to override the calibration data directory (defaults to adjacent RULER data). * **Bug Fixes** * Ensure calibration kernels run on the correct CUDA device; align measurement granularity and tile/block sizing; ignore padded query rows when counting skippable tiles. * **Tests** * Added GPU Triton calibration tests for end-to-end inference, multi-threshold stats, and decode-phase reporting. * **Documentation** * Updated changelog and example to expose the new option and flag. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com> Signed-off-by: Kai Xu <kaix@nvidia.com> Co-authored-by: Kai Xu <kaix@nvidia.com>
1 parent e2c3e03 commit 8b01ba4

11 files changed

Lines changed: 553 additions & 140 deletions

File tree

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ Changelog
6161
- The PTQ example scripts ``examples/llm_ptq/hf_ptq.py``, ``examples/llm_ptq/multinode_ptq.py`` and ``examples/megatron_bridge/quantize.py`` now derive their ``--qformat`` / ``--kv_cache_qformat`` (``--quant_cfg`` / ``--kv_cache_quant`` for Megatron-Bridge) CLI vocabularies by discovering the YAML presets under ``modelopt_recipes/configs/ptq/presets/{model,kv}/`` rather than carrying hardcoded ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` tables. The discovery helper, alias table and ready-built ``QUANT_CFG_CHOICES`` / ``KV_QUANT_CFG_CHOICES`` mappings now live in ``modelopt.recipe.presets`` and are shared by all three scripts. Presets are loaded eagerly into a plain dict at import. Adding a new preset YAML makes it available on the CLI of all three with no script change — note this means each script now accepts every preset under those directories, not just a previously curated subset. All previously-supported short names (``int8_sq``, ``nvfp4_awq``, ``fp8_pb_wo``, ``nvfp4_mse``, ``w4a8_awq``, ``nvfp4_local_hessian``, ``fp8_pc_pt``, ``int8_wo``) keep working via a small deprecation alias table; new formats should be exposed as preset YAMLs (or, longer term, as full ``--recipe`` recipes).
6262
- Add ``configs/ptq/presets/kv/fp8_cast.yaml`` and ``configs/ptq/presets/kv/nvfp4_cast.yaml``, promoting ``fp8_cast`` / ``nvfp4_cast`` to first-class KV presets composed from the existing ``kv_fp8_cast`` / ``kv_nvfp4_cast`` unit fragments. The previous runtime ``use_constant_amax`` post-edit in ``hf_ptq.py`` is removed; ``use_constant_amax: true`` now lives in the YAML and is therefore authoritative. **Custom (out-of-tree) recipes that target a cast KV format must set ``use_constant_amax: true`` themselves on the ``[kv]_bmm_quantizer`` config** — in-tree recipes already do via the ``kv_*_cast`` units.
6363
- Add DMD2 distillation for few-step diffusion models in ``examples/diffusers/fastgen/``: distill Qwen-Image into a 4/8-step student via Distribution Matching Distillation. See `examples/diffusers/fastgen/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/diffusers/fastgen>`_ for details.
64+
- Add post-training quantization (PTQ) example for the Megatron-Bridge framework: ``examples/megatron_bridge/quantize.py`` calibrates an HF model (via ``--quant_cfg`` alias / full config name or a ``--recipe`` YAML, with optional KV-cache quant, weight-only, compression, and MoE expert-ratio calibration) and saves a Megatron checkpoint (tensor / pipeline / expert parallelism supported), and ``examples/megatron_bridge/export.py`` converts that checkpoint to a deployable HuggingFace (unified) checkpoint for TensorRT-LLM / vLLM / SGLang. See `examples/megatron_bridge/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge>`_ for details.
65+
- Add ``mtsa.config.SKIP_SOFTMAX_TRITON_CALIB`` for skip-softmax attention-sparsity calibration through the fused Triton ``attention_calibrate`` kernel (HF ``modelopt_triton`` backend), measuring multi-threshold tile-skip statistics the way the Triton inference kernel actually skips tiles for both prefill and decode. Exposed as ``--sparse_attn_cfg skip_softmax_triton_calib`` in ``examples/llm_sparsity/attention_sparsity/hf_sa.py`` (with a new ``--calib_data_dir`` flag for RULER calibration data).
6466

6567
**Bug Fixes**
6668

examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from modelopt.torch.sparsity.attention_sparsity.config import (
3232
SKIP_SOFTMAX_CALIB,
3333
SKIP_SOFTMAX_CALIB_SPARSE24,
34+
SKIP_SOFTMAX_TRITON_CALIB,
3435
SPARSE_SOFTMAX_DEFAULT,
3536
)
3637
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
@@ -44,6 +45,7 @@
4445
SPARSE_ATTN_CFG_CHOICES = {
4546
"skip_softmax_calib": SKIP_SOFTMAX_CALIB,
4647
"skip_softmax_calib_sparse24": SKIP_SOFTMAX_CALIB_SPARSE24,
48+
"skip_softmax_triton_calib": SKIP_SOFTMAX_TRITON_CALIB,
4749
"sparse_softmax": SPARSE_SOFTMAX_DEFAULT,
4850
}
4951

@@ -186,6 +188,15 @@ def main(args):
186188
calib["max_seqlen"] = args.calib_max_seqlen
187189
if args.calib_chunk_size is not None:
188190
calib["chunk_size"] = args.calib_chunk_size
191+
# Point RULER calibration at the data downloaded by download_ruler_data.sh
192+
# (next to this script) unless the user overrides it. The NIAH essay
193+
# haystack requires this directory.
194+
calib.setdefault(
195+
"data_dir",
196+
args.calib_data_dir
197+
if args.calib_data_dir is not None
198+
else str(Path(__file__).parent / "data"),
199+
)
189200

190201
model = mtsa.sparsify(model, config=sparse_config)
191202
print("Sparse attention applied successfully!")
@@ -302,6 +313,14 @@ def main(args):
302313
default=None,
303314
help="Chunk size for calibration prefill. Overrides config value.",
304315
)
316+
parser.add_argument(
317+
"--calib_data_dir",
318+
type=str,
319+
default=None,
320+
help="Path to RULER calibration data (contains an 'essays' subdir). "
321+
"Defaults to the 'data' directory next to this script "
322+
"(populated by download_ruler_data.sh).",
323+
)
305324

306325
args = parser.parse_args()
307326
main(args)

modelopt/torch/kernels/common/attention/hf_triton_attention.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727

2828
from modelopt.torch.kernels.common.attention.triton_fa import attention
2929

30+
# Skip-softmax calibration config and counters live on the module's
31+
# ``_sparse_method_instance`` (HF passes the owning module to
32+
# ``triton_attention_forward``), so no separate thread-local state is needed.
33+
3034

3135
def _seq_lens_from_mask(
3236
attention_mask: torch.Tensor | None,
@@ -105,20 +109,49 @@ def triton_attention_forward(
105109
kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
106110
kw["max_input_len_k"] = seq_k
107111

108-
# Sparse attention params
112+
# Sparse-attention method instance. It carries the inference threshold and,
113+
# during calibration, both the calibration config and the accumulated
114+
# tile-skip counters. Available here because HF passes the owning module.
109115
method = getattr(module, "_sparse_method_instance", None)
110116

117+
# Calibration mode: run the calibration kernel, which computes full attention
118+
# while counting, per candidate threshold, how many KV tiles would be skipped.
119+
# The sparse-attention kwargs below are intentionally not added in this branch.
120+
if method is not None and getattr(method, "_calibration_mode", False):
121+
trials = getattr(method, "_threshold_trials", None)
122+
# Deferred: the package __init__ imports this module, so importing
123+
# attention_calibrate at module top would be circular.
124+
from modelopt.torch.kernels.common.attention import attention_calibrate
125+
126+
if trials and attention_calibrate is not None:
127+
o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)
128+
129+
# Accumulate counters across all attention calls in this forward pass.
130+
# The method instance is per-module so the accumulator stays on one
131+
# device, but guard the add against a device mismatch just in case.
132+
prev = getattr(method, "_hf_calibration_counters", None)
133+
method._hf_calibration_counters = (
134+
counters if prev is None else prev + counters.to(prev.device)
135+
)
136+
method._hf_calibration_seq_k = seq_k
137+
method._hf_calibration_is_decode = is_decode
138+
139+
return (o.view(batch, seq_len, num_heads, head_dim), None)
140+
111141
# N:M sparse softmax: prefill only (no perf benefit for decode)
112142
if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False):
113143
kw["sparsity_n"] = method.sparsity_n
114144
kw["sparsity_m"] = method.sparsity_m
115145
kw["dense_sink_tokens"] = method.dense_sink_tokens
116146
kw["dense_recent_tokens"] = method.dense_recent_tokens
117147

118-
# Skip-softmax: applies to both prefill and decode
148+
# Skip-softmax: applies to both prefill and decode. Prefer the method's
149+
# per-phase calibrated dynamic threshold (scale_factor / seq_k); fall back
150+
# to the static threshold when uncalibrated.
119151
if method is not None and getattr(module, "_apply_skip_softmax", False):
120-
if method.skip_softmax_threshold:
121-
kw["skip_softmax_threshold"] = method.skip_softmax_threshold
152+
threshold = method.get_inference_threshold(seq_len, seq_k)
153+
if threshold:
154+
kw["skip_softmax_threshold"] = threshold
122155

123156
o = attention(q, k, v, **kw)
124157

modelopt/torch/kernels/common/attention/triton_fa.py

Lines changed: 102 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ def _load_sparsity_helpers() -> None:
8080
_FWD_CONFIGS = [triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=1, num_warps=4)]
8181

8282
_MEASURE_BLOCK_M = 128
83-
_MEASURE_BLOCK_N = 64
83+
# 128 so the kernel sparsity-measurement block matches the PyTorch
84+
# flash_skip_softmax calibration block (br = bc = 128) and the Triton
85+
# calibration kernel; otherwise the two measure at different granularities.
86+
_MEASURE_BLOCK_N = 128
8487
_MEASURE_NUM_STAGES = 1
8588
_MEASURE_NUM_WARPS = 4
8689

@@ -363,6 +366,8 @@ def _attn_fwd(
363366
skip_tile = _skip_softmax_decision(
364367
scores,
365368
row_max,
369+
q_pos,
370+
seq_len_q,
366371
SKIP_THRESHOLD_LOG2,
367372
Sparsity_total,
368373
Sparsity_skipped,
@@ -919,23 +924,29 @@ def forward(
919924
def grid(META):
920925
return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"]))
921926

922-
if do_measure:
923-
# Runtime counters mutate global tensors, so do not run them through
924-
# autotune candidate trials. Use one stable config for measurement.
925-
_attn_fwd.fn[grid](
926-
*fwd_args,
927-
**fwd_kwargs,
928-
BLOCK_M=_MEASURE_BLOCK_M,
929-
BLOCK_N=_MEASURE_BLOCK_N,
930-
num_warps=_MEASURE_NUM_WARPS,
931-
num_stages=_MEASURE_NUM_STAGES,
932-
)
933-
else:
934-
_attn_fwd[grid](
935-
*fwd_args,
936-
**fwd_kwargs,
937-
# BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune
938-
)
927+
# Triton launches on torch.cuda.current_device(), which is not
928+
# necessarily the device the tensors live on (e.g. under accelerate
929+
# device_map="auto" sharding). Activate the tensor's device so the
930+
# kernel dereferences the right pointers instead of triggering an
931+
# illegal memory access.
932+
with torch.cuda.device(q.device):
933+
if do_measure:
934+
# Runtime counters mutate global tensors, so do not run them through
935+
# autotune candidate trials. Use one stable config for measurement.
936+
_attn_fwd.fn[grid](
937+
*fwd_args,
938+
**fwd_kwargs,
939+
BLOCK_M=_MEASURE_BLOCK_M,
940+
BLOCK_N=_MEASURE_BLOCK_N,
941+
num_warps=_MEASURE_NUM_WARPS,
942+
num_stages=_MEASURE_NUM_STAGES,
943+
)
944+
else:
945+
_attn_fwd[grid](
946+
*fwd_args,
947+
**fwd_kwargs,
948+
# BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune
949+
)
939950

940951
# Store sparsity counters on the output tensor for retrieval by callers
941952
if do_measure:
@@ -970,23 +981,30 @@ def backward(ctx, grad_output):
970981
do = grad_output.contiguous()
971982
num_warps = 4
972983

984+
# Triton launches on torch.cuda.current_device(), which is not
985+
# necessarily the device the tensors live on (e.g. under accelerate
986+
# device_map="auto" sharding). Activate the tensor's device for each
987+
# launch so the kernels dereference the right pointers instead of
988+
# triggering an illegal memory access.
989+
973990
# Phase 1: delta = rowsum(O * dO)
974991
delta = torch.empty_like(lse)
975-
_attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))](
976-
o,
977-
do,
978-
delta,
979-
o.stride(0),
980-
o.stride(1),
981-
do.stride(0),
982-
do.stride(1),
983-
delta.stride(0),
984-
delta.stride(1),
985-
q.shape[0],
986-
HEAD_DIM=HEAD_DIM,
987-
BLOCK_D=BLOCK_D,
988-
BLOCK_M=BLOCK,
989-
)
992+
with torch.cuda.device(q.device):
993+
_attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))](
994+
o,
995+
do,
996+
delta,
997+
o.stride(0),
998+
o.stride(1),
999+
do.stride(0),
1000+
do.stride(1),
1001+
delta.stride(0),
1002+
delta.stride(1),
1003+
q.shape[0],
1004+
HEAD_DIM=HEAD_DIM,
1005+
BLOCK_D=BLOCK_D,
1006+
BLOCK_M=BLOCK,
1007+
)
9901008

9911009
dq = torch.zeros_like(q)
9921010
dk = torch.zeros_like(k)
@@ -1016,57 +1034,59 @@ def backward(ctx, grad_output):
10161034
)
10171035

10181036
# Phase 2: dK, dV
1019-
_attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))](
1020-
*bwd_args[:4],
1021-
dk,
1022-
dv,
1023-
*bwd_args[4:],
1024-
dk.stride(0),
1025-
dk.stride(1),
1026-
dv.stride(0),
1027-
dv.stride(1),
1028-
lse.stride(0),
1029-
lse.stride(1),
1030-
kv_group_num=ctx.kv_group_num,
1031-
BLOCK_M=BLOCK,
1032-
BLOCK_D=BLOCK_D,
1033-
BLOCK_N=BLOCK,
1034-
IS_CAUSAL=ctx.is_causal,
1035-
HEAD_DIM=HEAD_DIM,
1036-
SPARSITY_N=ctx.sparsity_n,
1037-
SPARSITY_M=ctx.sparsity_m,
1038-
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
1039-
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
1040-
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
1041-
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
1042-
num_warps=num_warps,
1043-
num_stages=1,
1044-
)
1037+
with torch.cuda.device(q.device):
1038+
_attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))](
1039+
*bwd_args[:4],
1040+
dk,
1041+
dv,
1042+
*bwd_args[4:],
1043+
dk.stride(0),
1044+
dk.stride(1),
1045+
dv.stride(0),
1046+
dv.stride(1),
1047+
lse.stride(0),
1048+
lse.stride(1),
1049+
kv_group_num=ctx.kv_group_num,
1050+
BLOCK_M=BLOCK,
1051+
BLOCK_D=BLOCK_D,
1052+
BLOCK_N=BLOCK,
1053+
IS_CAUSAL=ctx.is_causal,
1054+
HEAD_DIM=HEAD_DIM,
1055+
SPARSITY_N=ctx.sparsity_n,
1056+
SPARSITY_M=ctx.sparsity_m,
1057+
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
1058+
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
1059+
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
1060+
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
1061+
num_warps=num_warps,
1062+
num_stages=1,
1063+
)
10451064

10461065
# Phase 3: dQ
1047-
_attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))](
1048-
*bwd_args[:4],
1049-
dq,
1050-
*bwd_args[4:],
1051-
dq.stride(0),
1052-
dq.stride(1),
1053-
lse.stride(0),
1054-
lse.stride(1),
1055-
kv_group_num=ctx.kv_group_num,
1056-
BLOCK_M=BLOCK,
1057-
BLOCK_D=BLOCK_D,
1058-
BLOCK_N=BLOCK,
1059-
IS_CAUSAL=ctx.is_causal,
1060-
HEAD_DIM=HEAD_DIM,
1061-
SPARSITY_N=ctx.sparsity_n,
1062-
SPARSITY_M=ctx.sparsity_m,
1063-
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
1064-
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
1065-
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
1066-
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
1067-
num_warps=num_warps,
1068-
num_stages=1,
1069-
)
1066+
with torch.cuda.device(q.device):
1067+
_attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))](
1068+
*bwd_args[:4],
1069+
dq,
1070+
*bwd_args[4:],
1071+
dq.stride(0),
1072+
dq.stride(1),
1073+
lse.stride(0),
1074+
lse.stride(1),
1075+
kv_group_num=ctx.kv_group_num,
1076+
BLOCK_M=BLOCK,
1077+
BLOCK_D=BLOCK_D,
1078+
BLOCK_N=BLOCK,
1079+
IS_CAUSAL=ctx.is_causal,
1080+
HEAD_DIM=HEAD_DIM,
1081+
SPARSITY_N=ctx.sparsity_n,
1082+
SPARSITY_M=ctx.sparsity_m,
1083+
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
1084+
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
1085+
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
1086+
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
1087+
num_warps=num_warps,
1088+
num_stages=1,
1089+
)
10701090

10711091
return (
10721092
dq,

0 commit comments

Comments
 (0)