Skip to content

Commit 8490c42

Browse files
committed
Added decode calibration
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
1 parent a5772b3 commit 8490c42

9 files changed

Lines changed: 298 additions & 316 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Changelog
4343
- Add Nemotron-3-Super-120B-A12B PTQ recipes ``modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml`` (MSE-mixed) and ``super-nvfp4-max-calib.yaml`` (max-calib mixed): NVFP4 W4A4 routed experts + FP8 per-tensor shared experts / Mamba in/out_proj + FP8 KV cache.
4444
- Add quantized ``nn.Embedding`` support. ``nn.Embedding`` is now registered in ``QuantModuleRegistry`` and exposes ``weight_quantizer`` (embedding table), ``output_quantizer`` (lookup activations), and a permanently disabled ``input_quantizer`` placeholder — embedding inputs are integer indices and cannot be fake-quantized, so direct ``enable*()`` calls raise. ``export_hf_checkpoint`` packs quantized embedding weights alongside Linear layers. Embedding quantizers are opt-in (``parent_class: nn.Embedding`` disabled by default).
4545
- Add post-training quantization (PTQ) example for the Megatron-Bridge framework: ``examples/megatron_bridge/quantize.py`` calibrates an HF model (via ``--quant_cfg`` alias / full config name or a ``--recipe`` YAML, with optional KV-cache quant, weight-only, compression, and MoE expert-ratio calibration) and saves a Megatron checkpoint (tensor / pipeline / expert parallelism supported), and ``examples/megatron_bridge/export.py`` converts that checkpoint to a deployable HuggingFace (unified) checkpoint for TensorRT-LLM / vLLM / SGLang. See `examples/megatron_bridge/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge>`_ for details.
46+
- Add ``mtsa.config.SKIP_SOFTMAX_TRITON_CALIB`` for skip-softmax attention-sparsity calibration through the fused Triton ``attention_calibrate`` kernel (HF ``modelopt_triton`` backend), measuring multi-threshold tile-skip statistics the way the Triton inference kernel actually skips tiles for both prefill and decode. Exposed as ``--sparse_attn_cfg skip_softmax_triton_calib`` in ``examples/llm_sparsity/attention_sparsity/hf_sa.py`` (with a new ``--calib_data_dir`` flag for RULER calibration data).
4647

4748
**Bug Fixes**
4849

modelopt/torch/kernels/common/attention/hf_triton_attention.py

Lines changed: 23 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -22,70 +22,14 @@
2222

2323
from __future__ import annotations
2424

25-
import threading
26-
2725
import torch
2826
import torch.nn as nn
2927

3028
from modelopt.torch.kernels.common.attention.triton_fa import attention
3129

32-
# ---------------------------------------------------------------------------
33-
# Thread-local skip-softmax calibration config for the HF (modelopt_triton) backend
34-
# ---------------------------------------------------------------------------
35-
# Mirrors the diffusers/LTX backends: during calibration the Triton calibration
36-
# kernel measures multi-threshold tile-skip statistics without skipping any tiles.
37-
# Inference-time config (skip threshold / scale factor) is still read from the
38-
# module/method attributes in ``triton_attention_forward`` — only calibration
39-
# state lives here.
40-
_thread_local = threading.local()
41-
42-
43-
def set_hf_triton_skip_softmax_config(
44-
threshold: float | None = None,
45-
calibration_mode: bool = False,
46-
threshold_trials: list[float] | None = None,
47-
scale_factor: float | None = None,
48-
measure_sparsity: bool = False,
49-
) -> None:
50-
"""Set thread-local skip-softmax calibration config for the next forward.
51-
52-
Accepts the same keyword arguments as the diffusers/LTX backends so the
53-
shared :class:`TritonSkipSoftmaxMethod` can configure all backends uniformly.
54-
Only the calibration fields are consumed by the HF backend; the inference
55-
fields (``threshold``/``scale_factor``/``measure_sparsity``) are accepted for
56-
signature compatibility but ignored here, since the HF inference path reads
57-
its threshold from the module/method attributes.
58-
59-
Args:
60-
threshold: Ignored by the HF backend (inference threshold comes from the module).
61-
calibration_mode: If True, route prefill attention through the calibration kernel.
62-
threshold_trials: Thresholds to measure sparsity for (used when calibration_mode=True).
63-
scale_factor: Ignored by the HF backend.
64-
measure_sparsity: Ignored by the HF backend.
65-
"""
66-
_thread_local.calibration_mode = calibration_mode
67-
_thread_local.threshold_trials = threshold_trials
68-
# Counters accumulated across all attention calls in one forward pass.
69-
_thread_local.calibration_counters = None
70-
_thread_local.calibration_seq_k = None
71-
72-
73-
def clear_hf_triton_skip_softmax_config() -> None:
74-
"""Clear thread-local skip-softmax calibration config."""
75-
_thread_local.calibration_mode = False
76-
_thread_local.threshold_trials = None
77-
_thread_local.calibration_counters = None
78-
_thread_local.calibration_seq_k = None
79-
80-
81-
def get_calibration_counters() -> torch.Tensor | None:
82-
"""Return accumulated calibration counters ``[num_thresholds, 2]`` or None."""
83-
return getattr(_thread_local, "calibration_counters", None)
84-
85-
86-
def get_calibration_seq_k() -> int | None:
87-
"""Return KV sequence length observed during calibration, or None."""
88-
return getattr(_thread_local, "calibration_seq_k", None)
30+
# Skip-softmax calibration config and counters live on the module's
31+
# ``_sparse_method_instance`` (HF passes the owning module to
32+
# ``triton_attention_forward``), so no separate thread-local state is needed.
8933

9034

9135
def _seq_lens_from_mask(
@@ -165,29 +109,35 @@ def triton_attention_forward(
165109
kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
166110
kw["max_input_len_k"] = seq_k
167111

168-
# --- Calibration mode: collect multi-threshold tile-skip stats (prefill only) ---
169-
# Run the calibration kernel, which computes full (non-skipped) attention while
170-
# counting, per candidate threshold, how many KV tiles would be skipped. ``kw`` at
171-
# this point holds only the base attention args that ``attention_calibrate`` accepts;
172-
# the sparse-attention kwargs below are intentionally not added in this branch.
173-
calib_mode = getattr(_thread_local, "calibration_mode", False)
174-
if calib_mode and not is_decode:
175-
trials = getattr(_thread_local, "threshold_trials", None)
112+
# Sparse-attention method instance. It carries the inference threshold and,
113+
# during calibration, both the calibration config and the accumulated
114+
# tile-skip counters. Available here because HF passes the owning module.
115+
method = getattr(module, "_sparse_method_instance", None)
116+
117+
# Calibration mode: run the calibration kernel, which computes full attention
118+
# while counting, per candidate threshold, how many KV tiles would be skipped.
119+
# The sparse-attention kwargs below are intentionally not added in this branch.
120+
if method is not None and getattr(method, "_calibration_mode", False):
121+
trials = getattr(method, "_threshold_trials", None)
122+
# Deferred: the package __init__ imports this module, so importing
123+
# attention_calibrate at module top would be circular.
176124
from modelopt.torch.kernels.common.attention import attention_calibrate
177125

178126
if trials and attention_calibrate is not None:
179127
o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)
180128

181129
# Accumulate counters across all attention calls in this forward pass.
182-
prev = getattr(_thread_local, "calibration_counters", None)
183-
_thread_local.calibration_counters = counters if prev is None else prev + counters
184-
_thread_local.calibration_seq_k = seq_k
130+
# The method instance is per-module so the accumulator stays on one
131+
# device, but guard the add against a device mismatch just in case.
132+
prev = getattr(method, "_hf_calibration_counters", None)
133+
method._hf_calibration_counters = (
134+
counters if prev is None else prev + counters.to(prev.device)
135+
)
136+
method._hf_calibration_seq_k = seq_k
137+
method._hf_calibration_is_decode = is_decode
185138

186139
return (o.view(batch, seq_len, num_heads, head_dim), None)
187140

188-
# Sparse attention params
189-
method = getattr(module, "_sparse_method_instance", None)
190-
191141
# N:M sparse softmax: prefill only (no perf benefit for decode)
192142
if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False):
193143
kw["sparsity_n"] = method.sparsity_n
@@ -233,10 +183,6 @@ def register_triton_attention() -> bool:
233183

234184

235185
__all__ = [
236-
"clear_hf_triton_skip_softmax_config",
237-
"get_calibration_counters",
238-
"get_calibration_seq_k",
239186
"register_triton_attention",
240-
"set_hf_triton_skip_softmax_config",
241187
"triton_attention_forward",
242188
]

modelopt/torch/kernels/common/attention/triton_fa.py

Lines changed: 96 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -919,23 +919,29 @@ def forward(
919919
def grid(META):
920920
return (batch, num_q_heads, triton.cdiv(max_input_len, META["BLOCK_M"]))
921921

922-
if do_measure:
923-
# Runtime counters mutate global tensors, so do not run them through
924-
# autotune candidate trials. Use one stable config for measurement.
925-
_attn_fwd.fn[grid](
926-
*fwd_args,
927-
**fwd_kwargs,
928-
BLOCK_M=_MEASURE_BLOCK_M,
929-
BLOCK_N=_MEASURE_BLOCK_N,
930-
num_warps=_MEASURE_NUM_WARPS,
931-
num_stages=_MEASURE_NUM_STAGES,
932-
)
933-
else:
934-
_attn_fwd[grid](
935-
*fwd_args,
936-
**fwd_kwargs,
937-
# BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune
938-
)
922+
# Triton launches on torch.cuda.current_device(), which is not
923+
# necessarily the device the tensors live on (e.g. under accelerate
924+
# device_map="auto" sharding). Activate the tensor's device so the
925+
# kernel dereferences the right pointers instead of triggering an
926+
# illegal memory access.
927+
with torch.cuda.device(q.device):
928+
if do_measure:
929+
# Runtime counters mutate global tensors, so do not run them through
930+
# autotune candidate trials. Use one stable config for measurement.
931+
_attn_fwd.fn[grid](
932+
*fwd_args,
933+
**fwd_kwargs,
934+
BLOCK_M=_MEASURE_BLOCK_M,
935+
BLOCK_N=_MEASURE_BLOCK_N,
936+
num_warps=_MEASURE_NUM_WARPS,
937+
num_stages=_MEASURE_NUM_STAGES,
938+
)
939+
else:
940+
_attn_fwd[grid](
941+
*fwd_args,
942+
**fwd_kwargs,
943+
# BLOCK_M, BLOCK_N, num_warps, num_stages chosen by autotune
944+
)
939945

940946
# Store sparsity counters on the output tensor for retrieval by callers
941947
if do_measure:
@@ -970,23 +976,30 @@ def backward(ctx, grad_output):
970976
do = grad_output.contiguous()
971977
num_warps = 4
972978

979+
# Triton launches on torch.cuda.current_device(), which is not
980+
# necessarily the device the tensors live on (e.g. under accelerate
981+
# device_map="auto" sharding). Activate the tensor's device for each
982+
# launch so the kernels dereference the right pointers instead of
983+
# triggering an illegal memory access.
984+
973985
# Phase 1: delta = rowsum(O * dO)
974986
delta = torch.empty_like(lse)
975-
_attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))](
976-
o,
977-
do,
978-
delta,
979-
o.stride(0),
980-
o.stride(1),
981-
do.stride(0),
982-
do.stride(1),
983-
delta.stride(0),
984-
delta.stride(1),
985-
q.shape[0],
986-
HEAD_DIM=HEAD_DIM,
987-
BLOCK_D=BLOCK_D,
988-
BLOCK_M=BLOCK,
989-
)
987+
with torch.cuda.device(q.device):
988+
_attn_bwd_preprocess[(ctx.num_q_heads, triton.cdiv(q.shape[0], BLOCK))](
989+
o,
990+
do,
991+
delta,
992+
o.stride(0),
993+
o.stride(1),
994+
do.stride(0),
995+
do.stride(1),
996+
delta.stride(0),
997+
delta.stride(1),
998+
q.shape[0],
999+
HEAD_DIM=HEAD_DIM,
1000+
BLOCK_D=BLOCK_D,
1001+
BLOCK_M=BLOCK,
1002+
)
9901003

9911004
dq = torch.zeros_like(q)
9921005
dk = torch.zeros_like(k)
@@ -1016,57 +1029,59 @@ def backward(ctx, grad_output):
10161029
)
10171030

10181031
# Phase 2: dK, dV
1019-
_attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))](
1020-
*bwd_args[:4],
1021-
dk,
1022-
dv,
1023-
*bwd_args[4:],
1024-
dk.stride(0),
1025-
dk.stride(1),
1026-
dv.stride(0),
1027-
dv.stride(1),
1028-
lse.stride(0),
1029-
lse.stride(1),
1030-
kv_group_num=ctx.kv_group_num,
1031-
BLOCK_M=BLOCK,
1032-
BLOCK_D=BLOCK_D,
1033-
BLOCK_N=BLOCK,
1034-
IS_CAUSAL=ctx.is_causal,
1035-
HEAD_DIM=HEAD_DIM,
1036-
SPARSITY_N=ctx.sparsity_n,
1037-
SPARSITY_M=ctx.sparsity_m,
1038-
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
1039-
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
1040-
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
1041-
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
1042-
num_warps=num_warps,
1043-
num_stages=1,
1044-
)
1032+
with torch.cuda.device(q.device):
1033+
_attn_bwd_dkdv[(ctx.batch, ctx.num_kv_heads, triton.cdiv(ctx.max_input_len_k, BLOCK))](
1034+
*bwd_args[:4],
1035+
dk,
1036+
dv,
1037+
*bwd_args[4:],
1038+
dk.stride(0),
1039+
dk.stride(1),
1040+
dv.stride(0),
1041+
dv.stride(1),
1042+
lse.stride(0),
1043+
lse.stride(1),
1044+
kv_group_num=ctx.kv_group_num,
1045+
BLOCK_M=BLOCK,
1046+
BLOCK_D=BLOCK_D,
1047+
BLOCK_N=BLOCK,
1048+
IS_CAUSAL=ctx.is_causal,
1049+
HEAD_DIM=HEAD_DIM,
1050+
SPARSITY_N=ctx.sparsity_n,
1051+
SPARSITY_M=ctx.sparsity_m,
1052+
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
1053+
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
1054+
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
1055+
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
1056+
num_warps=num_warps,
1057+
num_stages=1,
1058+
)
10451059

10461060
# Phase 3: dQ
1047-
_attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))](
1048-
*bwd_args[:4],
1049-
dq,
1050-
*bwd_args[4:],
1051-
dq.stride(0),
1052-
dq.stride(1),
1053-
lse.stride(0),
1054-
lse.stride(1),
1055-
kv_group_num=ctx.kv_group_num,
1056-
BLOCK_M=BLOCK,
1057-
BLOCK_D=BLOCK_D,
1058-
BLOCK_N=BLOCK,
1059-
IS_CAUSAL=ctx.is_causal,
1060-
HEAD_DIM=HEAD_DIM,
1061-
SPARSITY_N=ctx.sparsity_n,
1062-
SPARSITY_M=ctx.sparsity_m,
1063-
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
1064-
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
1065-
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
1066-
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
1067-
num_warps=num_warps,
1068-
num_stages=1,
1069-
)
1061+
with torch.cuda.device(q.device):
1062+
_attn_bwd_dq[(ctx.batch, ctx.num_q_heads, triton.cdiv(ctx.max_input_len, BLOCK))](
1063+
*bwd_args[:4],
1064+
dq,
1065+
*bwd_args[4:],
1066+
dq.stride(0),
1067+
dq.stride(1),
1068+
lse.stride(0),
1069+
lse.stride(1),
1070+
kv_group_num=ctx.kv_group_num,
1071+
BLOCK_M=BLOCK,
1072+
BLOCK_D=BLOCK_D,
1073+
BLOCK_N=BLOCK,
1074+
IS_CAUSAL=ctx.is_causal,
1075+
HEAD_DIM=HEAD_DIM,
1076+
SPARSITY_N=ctx.sparsity_n,
1077+
SPARSITY_M=ctx.sparsity_m,
1078+
DENSE_SINK_TOKENS=ctx.dense_sink_tokens,
1079+
DENSE_RECENT_TOKENS=ctx.dense_recent_tokens,
1080+
APPLY_SKIP_SOFTMAX=ctx.apply_skip,
1081+
SKIP_THRESHOLD_LOG2=ctx.skip_threshold_log2,
1082+
num_warps=num_warps,
1083+
num_stages=1,
1084+
)
10701085

10711086
return (
10721087
dq,

0 commit comments

Comments
 (0)