Skip to content

Commit 37b77a0

Browse files
kaix-nv authored and rohansjoshi committed
Apply per-phase calibrated skip threshold at HF inference
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent 5ca71ad commit 37b77a0

2 files changed

Lines changed: 27 additions & 8 deletions

File tree

modelopt/torch/kernels/common/attention/hf_triton_attention.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -145,10 +145,13 @@ def triton_attention_forward(
145145
kw["dense_sink_tokens"] = method.dense_sink_tokens
146146
kw["dense_recent_tokens"] = method.dense_recent_tokens
147147

148-
# Skip-softmax: applies to both prefill and decode
148+
# Skip-softmax: applies to both prefill and decode. Prefer the method's
149+
# per-phase calibrated dynamic threshold (scale_factor / seq_k); fall back
150+
# to the static threshold when uncalibrated.
149151
if method is not None and getattr(module, "_apply_skip_softmax", False):
150-
if method.skip_softmax_threshold:
151-
kw["skip_softmax_threshold"] = method.skip_softmax_threshold
152+
threshold = method.get_inference_threshold(seq_len, seq_k)
153+
if threshold:
154+
kw["skip_softmax_threshold"] = threshold
152155

153156
o = attention(q, k, v, **kw)
154157

modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -133,20 +133,20 @@ def _triton_calibration_context(self, module):
133133
module._apply_skip_softmax = False
134134
self._clear_triton_backends()
135135

136-
def _get_scale_factor(self) -> float | None:
137-
"""Compute scale_factor from calibration params, or None if uncalibrated.
136+
def _get_scale_factor(self, phase: str = "prefill") -> float | None:
137+
"""Compute the scale_factor for ``phase`` from calibration params, or None.
138138
139-
The scale_factor is sequence-length-independent. Backends divide by the
139+
The scale_factor is sequence-length-independent. Callers divide by the
140140
actual ``seq_k`` at call time: ``threshold = scale_factor / seq_k``.
141141
"""
142142
if self.calibration_params and self.target_sparse_ratio:
143143
import math
144144
import warnings
145145

146-
params = self.calibration_params.get("prefill", {})
146+
params = self.calibration_params.get(phase, {})
147147
a = params.get("a", 0)
148148
b = params.get("b", 0)
149-
target = self.target_sparse_ratio.get("prefill", 0.5)
149+
target = self.target_sparse_ratio.get(phase, 0.5)
150150
if a > 0 and b > 0:
151151
# Warn if target is outside the calibrated range
152152
min_s = params.get("min_observed_sparsity")
@@ -167,6 +167,22 @@ def _get_scale_factor(self) -> float | None:
167167
return a * math.exp(b * target)
168168
return None
169169

170+
def get_inference_threshold(self, seq_q: int, seq_k: int) -> float | None:
171+
"""Return the skip threshold to apply for this call's phase.
172+
173+
Picks the phase from the query length (``decode`` when ``seq_q <= 1``,
174+
else ``prefill``) and returns the calibrated dynamic threshold
175+
``scale_factor(phase) / seq_k`` when the phase is calibrated, otherwise
176+
the static ``skip_softmax_threshold`` (or ``None`` to disable). This is
177+
what the HF backend applies; it keeps prefill and decode on their own
178+
calibrated ``(a, b)`` instead of forcing decode onto prefill's.
179+
"""
180+
phase = "decode" if seq_q <= 1 else "prefill"
181+
scale_factor = self._get_scale_factor(phase)
182+
if scale_factor is not None and seq_k > 0:
183+
return scale_factor / seq_k
184+
return self.skip_softmax_threshold or None
185+
170186
@staticmethod
171187
@contextmanager
172188
def _get_diffusers_backend_context():

0 commit comments

Comments
 (0)