|
22 | 22 |
|
23 | 23 | from __future__ import annotations |
24 | 24 |
|
25 | | -import threading |
26 | | - |
27 | 25 | import torch |
28 | 26 | import torch.nn as nn |
29 | 27 |
|
30 | 28 | from modelopt.torch.kernels.common.attention.triton_fa import attention |
31 | 29 |
|
32 | | -# --------------------------------------------------------------------------- |
33 | | -# Thread-local skip-softmax calibration config for the HF (modelopt_triton) backend |
34 | | -# --------------------------------------------------------------------------- |
35 | | -# Mirrors the diffusers/LTX backends: during calibration the Triton calibration |
36 | | -# kernel measures multi-threshold tile-skip statistics without skipping any tiles. |
37 | | -# Inference-time config (skip threshold / scale factor) is still read from the |
38 | | -# module/method attributes in ``triton_attention_forward`` — only calibration |
39 | | -# state lives here. |
40 | | -_thread_local = threading.local() |
41 | | - |
42 | | - |
43 | | -def set_hf_triton_skip_softmax_config( |
44 | | - threshold: float | None = None, |
45 | | - calibration_mode: bool = False, |
46 | | - threshold_trials: list[float] | None = None, |
47 | | - scale_factor: float | None = None, |
48 | | - measure_sparsity: bool = False, |
49 | | -) -> None: |
50 | | - """Set thread-local skip-softmax calibration config for the next forward. |
51 | | -
|
52 | | - Accepts the same keyword arguments as the diffusers/LTX backends so the |
53 | | - shared :class:`TritonSkipSoftmaxMethod` can configure all backends uniformly. |
54 | | - Only the calibration fields are consumed by the HF backend; the inference |
55 | | - fields (``threshold``/``scale_factor``/``measure_sparsity``) are accepted for |
56 | | - signature compatibility but ignored here, since the HF inference path reads |
57 | | - its threshold from the module/method attributes. |
58 | | -
|
59 | | - Args: |
60 | | - threshold: Ignored by the HF backend (inference threshold comes from the module). |
61 | | - calibration_mode: If True, route prefill attention through the calibration kernel. |
62 | | - threshold_trials: Thresholds to measure sparsity for (used when calibration_mode=True). |
63 | | - scale_factor: Ignored by the HF backend. |
64 | | - measure_sparsity: Ignored by the HF backend. |
65 | | - """ |
66 | | - _thread_local.calibration_mode = calibration_mode |
67 | | - _thread_local.threshold_trials = threshold_trials |
68 | | - # Counters accumulated across all attention calls in one forward pass. |
69 | | - _thread_local.calibration_counters = None |
70 | | - _thread_local.calibration_seq_k = None |
71 | | - |
72 | | - |
73 | | -def clear_hf_triton_skip_softmax_config() -> None: |
74 | | - """Clear thread-local skip-softmax calibration config.""" |
75 | | - _thread_local.calibration_mode = False |
76 | | - _thread_local.threshold_trials = None |
77 | | - _thread_local.calibration_counters = None |
78 | | - _thread_local.calibration_seq_k = None |
79 | | - |
80 | | - |
81 | | -def get_calibration_counters() -> torch.Tensor | None: |
82 | | - """Return accumulated calibration counters ``[num_thresholds, 2]`` or None.""" |
83 | | - return getattr(_thread_local, "calibration_counters", None) |
84 | | - |
85 | | - |
86 | | -def get_calibration_seq_k() -> int | None: |
87 | | - """Return KV sequence length observed during calibration, or None.""" |
88 | | - return getattr(_thread_local, "calibration_seq_k", None) |
| 30 | +# Skip-softmax calibration config and counters live on the module's |
| 31 | +# ``_sparse_method_instance`` (HF passes the owning module to |
| 32 | +# ``triton_attention_forward``), so no separate thread-local state is needed. |
89 | 33 |
|
90 | 34 |
|
91 | 35 | def _seq_lens_from_mask( |
@@ -165,29 +109,35 @@ def triton_attention_forward( |
165 | 109 | kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32) |
166 | 110 | kw["max_input_len_k"] = seq_k |
167 | 111 |
|
168 | | - # --- Calibration mode: collect multi-threshold tile-skip stats (prefill only) --- |
169 | | - # Run the calibration kernel, which computes full (non-skipped) attention while |
170 | | - # counting, per candidate threshold, how many KV tiles would be skipped. ``kw`` at |
171 | | - # this point holds only the base attention args that ``attention_calibrate`` accepts; |
172 | | - # the sparse-attention kwargs below are intentionally not added in this branch. |
173 | | - calib_mode = getattr(_thread_local, "calibration_mode", False) |
174 | | - if calib_mode and not is_decode: |
175 | | - trials = getattr(_thread_local, "threshold_trials", None) |
| 112 | + # Sparse-attention method instance. It carries the inference threshold and, |
| 113 | + # during calibration, both the calibration config and the accumulated |
| 114 | + # tile-skip counters. Available here because HF passes the owning module. |
| 115 | + method = getattr(module, "_sparse_method_instance", None) |
| 116 | + |
| 117 | + # Calibration mode: run the calibration kernel, which computes full attention |
| 118 | + # while counting, per candidate threshold, how many KV tiles would be skipped. |
| 119 | + # The sparse-attention kwargs below are intentionally not added in this branch. |
| 120 | + if method is not None and getattr(method, "_calibration_mode", False): |
| 121 | + trials = getattr(method, "_threshold_trials", None) |
| 122 | + # Deferred: the package __init__ imports this module, so importing |
| 123 | + # attention_calibrate at module top would be circular. |
176 | 124 | from modelopt.torch.kernels.common.attention import attention_calibrate |
177 | 125 |
|
178 | 126 | if trials and attention_calibrate is not None: |
179 | 127 | o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials) |
180 | 128 |
|
181 | 129 | # Accumulate counters across all attention calls in this forward pass. |
182 | | - prev = getattr(_thread_local, "calibration_counters", None) |
183 | | - _thread_local.calibration_counters = counters if prev is None else prev + counters |
184 | | - _thread_local.calibration_seq_k = seq_k |
| 130 | + # The method instance is per-module, so the accumulator normally stays on |
| 131 | + # one device; new counters are moved to its device before the add regardless. |
| 132 | + prev = getattr(method, "_hf_calibration_counters", None) |
| 133 | + method._hf_calibration_counters = ( |
| 134 | + counters if prev is None else prev + counters.to(prev.device) |
| 135 | + ) |
| 136 | + method._hf_calibration_seq_k = seq_k |
| 137 | + method._hf_calibration_is_decode = is_decode |
185 | 138 |
|
186 | 139 | return (o.view(batch, seq_len, num_heads, head_dim), None) |
187 | 140 |
|
188 | | - # Sparse attention params |
189 | | - method = getattr(module, "_sparse_method_instance", None) |
190 | | - |
191 | 141 | # N:M sparse softmax: prefill only (no perf benefit for decode) |
192 | 142 | if method is not None and not is_decode and getattr(module, "_apply_sparse_nm", False): |
193 | 143 | kw["sparsity_n"] = method.sparsity_n |
@@ -233,10 +183,6 @@ def register_triton_attention() -> bool: |
233 | 183 |
|
234 | 184 |
|
235 | 185 | __all__ = [ |
236 | | - "clear_hf_triton_skip_softmax_config", |
237 | | - "get_calibration_counters", |
238 | | - "get_calibration_seq_k", |
239 | 186 | "register_triton_attention", |
240 | | - "set_hf_triton_skip_softmax_config", |
241 | 187 | "triton_attention_forward", |
242 | 188 | ] |
0 commit comments