@@ -133,20 +133,20 @@ def _triton_calibration_context(self, module):
133133 module ._apply_skip_softmax = False
134134 self ._clear_triton_backends ()
135135
136- def _get_scale_factor (self ) -> float | None :
137- """Compute scale_factor from calibration params, or None if uncalibrated .
136+ def _get_scale_factor (self , phase : str = "prefill" ) -> float | None :
137+ """Compute the scale_factor for ``phase`` from calibration params, or None.
138138
139- The scale_factor is sequence-length-independent. Backends divide by the
139+ The scale_factor is sequence-length-independent. Callers divide by the
140140 actual ``seq_k`` at call time: ``threshold = scale_factor / seq_k``.
141141 """
142142 if self .calibration_params and self .target_sparse_ratio :
143143 import math
144144 import warnings
145145
146- params = self .calibration_params .get ("prefill" , {})
146+ params = self .calibration_params .get (phase , {})
147147 a = params .get ("a" , 0 )
148148 b = params .get ("b" , 0 )
149- target = self .target_sparse_ratio .get ("prefill" , 0.5 )
149+ target = self .target_sparse_ratio .get (phase , 0.5 )
150150 if a > 0 and b > 0 :
151151 # Warn if target is outside the calibrated range
152152 min_s = params .get ("min_observed_sparsity" )
@@ -167,6 +167,22 @@ def _get_scale_factor(self) -> float | None:
167167 return a * math .exp (b * target )
168168 return None
169169
170+ def get_inference_threshold (self , seq_q : int , seq_k : int ) -> float | None :
171+ """Return the skip threshold to apply for this call's phase.
172+
173+ Picks the phase from the query length (``decode`` when ``seq_q == 1``,
174+ else ``prefill``) and returns the calibrated dynamic threshold
175+ ``scale_factor(phase) / seq_k`` when the phase is calibrated, otherwise
176+ the static ``skip_softmax_threshold`` (or ``None`` to disable). This is
177+ what the HF backend applies; it keeps prefill and decode on their own
178+ calibrated ``(a, b)`` instead of forcing decode onto prefill's.
179+ """
180+ phase = "decode" if seq_q <= 1 else "prefill"
181+ scale_factor = self ._get_scale_factor (phase )
182+ if scale_factor is not None and seq_k > 0 :
183+ return scale_factor / seq_k
184+ return self .skip_softmax_threshold or None
185+
170186 @staticmethod
171187 @contextmanager
172188 def _get_diffusers_backend_context ():
0 commit comments