Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions examples/llm_sparsity/attention_sparsity/hf_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from modelopt.torch.sparsity.attention_sparsity.config import (
SKIP_SOFTMAX_CALIB,
SKIP_SOFTMAX_CALIB_SPARSE24,
SKIP_SOFTMAX_TRITON_CALIB,
SPARSE_SOFTMAX_DEFAULT,
)
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
Expand All @@ -44,6 +45,7 @@
# Maps each selectable sparse-attention preset name to the config object
# imported from modelopt.torch.sparsity.attention_sparsity.config.
SPARSE_ATTN_CFG_CHOICES = {
    "skip_softmax_calib": SKIP_SOFTMAX_CALIB,
    "skip_softmax_calib_sparse24": SKIP_SOFTMAX_CALIB_SPARSE24,
    "skip_softmax_triton_calib": SKIP_SOFTMAX_TRITON_CALIB,
    "sparse_softmax": SPARSE_SOFTMAX_DEFAULT,
}

Expand Down Expand Up @@ -186,6 +188,15 @@ def main(args):
calib["max_seqlen"] = args.calib_max_seqlen
if args.calib_chunk_size is not None:
calib["chunk_size"] = args.calib_chunk_size
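# Point RULER calibration at the data downloaded by download_ruler_data.sh
# (next to this script) by default; --calib_data_dir replaces that default.
# NOTE: setdefault() leaves a "data_dir" entry that is already present in
# ``calib`` untouched, so a value coming from the config takes precedence
# over the CLI flag. The NIAH essay haystack requires this directory.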
calib.setdefault(
"data_dir",
args.calib_data_dir
if args.calib_data_dir is not None
else str(Path(__file__).parent / "data"),
)

model = mtsa.sparsify(model, config=sparse_config)
print("Sparse attention applied successfully!")
Expand Down Expand Up @@ -302,6 +313,14 @@ def main(args):
default=None,
help="Chunk size for calibration prefill. Overrides config value.",
)
parser.add_argument(
"--calib_data_dir",
type=str,
default=None,
help="Path to RULER calibration data (contains an 'essays' subdir). "
"Defaults to the 'data' directory next to this script "
"(populated by download_ruler_data.sh).",
)

args = parser.parse_args()
main(args)
84 changes: 84 additions & 0 deletions modelopt/torch/kernels/common/attention/hf_triton_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,71 @@

from __future__ import annotations

import threading

import torch
import torch.nn as nn

from modelopt.torch.kernels.common.attention.triton_fa import attention

# ---------------------------------------------------------------------------
# Thread-local skip-softmax calibration config for the HF (modelopt_triton) backend
# ---------------------------------------------------------------------------
# Mirrors the diffusers/LTX backends: during calibration the Triton calibration
# kernel measures multi-threshold tile-skip statistics without skipping any tiles.
# Inference-time config (skip threshold / scale factor) is still read from the
# module/method attributes in ``triton_attention_forward`` — only calibration
# state lives here.
_thread_local = threading.local()


def set_hf_triton_skip_softmax_config(
    threshold: float | None = None,
    calibration_mode: bool = False,
    threshold_trials: list[float] | None = None,
    scale_factor: float | None = None,
    measure_sparsity: bool = False,
) -> None:
    """Install the thread-local skip-softmax calibration state for the next forward.

    The signature matches the one used by the diffusers/LTX backends so that a
    single caller can forward the same keyword arguments to every backend.
    This function only stores ``calibration_mode`` and ``threshold_trials``;
    ``threshold``, ``scale_factor`` and ``measure_sparsity`` are accepted purely
    for signature compatibility and are not read here.

    Every call also resets the accumulated counters and the recorded KV
    sequence length, so statistics from an earlier forward pass are discarded.

    Args:
        threshold: Unused by the HF backend.
        calibration_mode: Whether prefill attention should go through the
            calibration kernel.
        threshold_trials: Candidate thresholds to measure when
            ``calibration_mode`` is True.
        scale_factor: Unused by the HF backend.
        measure_sparsity: Unused by the HF backend.
    """
    state = _thread_local

    # Start from empty accumulators: they are summed over all attention calls
    # of a single forward pass.
    state.calibration_counters = None
    state.calibration_seq_k = None

    # Only the calibration-related arguments are recorded.
    state.threshold_trials = threshold_trials
    state.calibration_mode = calibration_mode


def clear_hf_triton_skip_softmax_config() -> None:
    """Reset every thread-local calibration field to its inactive value."""
    cleared_state = {
        "calibration_mode": False,
        "threshold_trials": None,
        "calibration_counters": None,
        "calibration_seq_k": None,
    }
    for attr_name, cleared_value in cleared_state.items():
        setattr(_thread_local, attr_name, cleared_value)


def get_calibration_counters() -> torch.Tensor | None:
    """Return the accumulated calibration counters ``[num_thresholds, 2]``.

    Returns:
        The counters stored on the current thread, or ``None`` when nothing
        has been stored yet (including when the attribute was never set).
    """
    try:
        return _thread_local.calibration_counters
    except AttributeError:
        # No config was ever installed on this thread.
        return None


def get_calibration_seq_k() -> int | None:
    """Return the KV sequence length recorded during calibration.

    Returns:
        The value stored on the current thread, or ``None`` when nothing has
        been recorded (including when the attribute was never set).
    """
    try:
        return _thread_local.calibration_seq_k
    except AttributeError:
        # No config was ever installed on this thread.
        return None


def _seq_lens_from_mask(
attention_mask: torch.Tensor | None,
Expand Down Expand Up @@ -105,6 +165,26 @@ def triton_attention_forward(
kw["b_seq_len_k"] = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
kw["max_input_len_k"] = seq_k

# --- Calibration mode: collect multi-threshold tile-skip stats (prefill only) ---
# Run the calibration kernel, which computes full (non-skipped) attention while
# counting, per candidate threshold, how many KV tiles would be skipped. ``kw`` at
# this point holds only the base attention args that ``attention_calibrate`` accepts;
# the sparse-attention kwargs below are intentionally not added in this branch.
calib_mode = getattr(_thread_local, "calibration_mode", False)
if calib_mode and not is_decode:
trials = getattr(_thread_local, "threshold_trials", None)
from modelopt.torch.kernels.common.attention import attention_calibrate

if trials and attention_calibrate is not None:
o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)

# Accumulate counters across all attention calls in this forward pass.
prev = getattr(_thread_local, "calibration_counters", None)
_thread_local.calibration_counters = counters if prev is None else prev + counters
_thread_local.calibration_seq_k = seq_k

return (o.view(batch, seq_len, num_heads, head_dim), None)

# Sparse attention params
method = getattr(module, "_sparse_method_instance", None)

Expand Down Expand Up @@ -153,6 +233,10 @@ def register_triton_attention() -> bool:


# Public API of this module: the attention forward/registration entry points
# plus the thread-local calibration config setters and getters.
__all__ = [
    "clear_hf_triton_skip_softmax_config",
    "get_calibration_counters",
    "get_calibration_seq_k",
    "register_triton_attention",
    "set_hf_triton_skip_softmax_config",
    "triton_attention_forward",
]
31 changes: 31 additions & 0 deletions modelopt/torch/sparsity/attention_sparsity/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,36 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
}


# Preset: RULER calibration driven by the fused Triton calibration kernel
# (prefill only). It fits the same exponential-model calibration as
# SKIP_SOFTMAX_CALIB, but the tile-skip statistics come from the Triton
# ``attention_calibrate`` kernel -- i.e. they reflect how the Triton inference
# kernel actually skips tiles -- rather than from the PyTorch
# F.softmax-patching block path. This is faster on GPU because no per-block
# tensors are materialized.
SKIP_SOFTMAX_TRITON_CALIB = {
    "sparse_cfg": {
        "calibration": {
            # Only "prefill" is listed. Leaving out "decode" keeps its target
            # at 0.0, which skips decode calibration (the Triton calibration
            # kernel is prefill-oriented).
            "target_sparse_ratio": {"prefill": 0.5},
            "samples": 64,
            "max_seqlen": 16 * 1024,  # 16384 tokens
            # -1 selects full prefill (seq_q == seq_k, uniform batch=1), which
            # is what attention_calibrate was validated against. Chunked
            # prefill would exercise an untested KV-cache causal-offset path
            # in the kernel.
            "chunk_size": -1,
        },
        # Attention modules matched by this pattern use the Triton
        # skip-softmax method.
        "*attn*": {
            "method": "triton_skip_softmax",
            "backend": "triton",
            "enable": True,
        },
        # Everything else is left untouched.
        "default": {"enable": False},
    },
}


class VSAAttributeConfig(ModeloptBaseConfig):
"""Video Sparse Attention (VSA) attribute configuration.

Expand Down Expand Up @@ -738,6 +768,7 @@ class VSAConfig(SparseAttentionConfig):
"SKIP_SOFTMAX_CALIB",
"SKIP_SOFTMAX_CALIB_SPARSE24",
"SKIP_SOFTMAX_DEFAULT",
"SKIP_SOFTMAX_TRITON_CALIB",
"SKIP_SOFTMAX_TRITON_DEFAULT",
"SPARSE_SOFTMAX_DEFAULT",
"VSA_DEFAULT",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _get_diffusers_backend_context():
yield

def _set_triton_backends(self, **kwargs):
"""Set config on both diffusers and LTX Triton backends."""
"""Set config on the diffusers, LTX, and HF (modelopt_triton) Triton backends."""
try:
from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import (
set_triton_skip_softmax_config,
Expand All @@ -187,9 +187,17 @@ def _set_triton_backends(self, **kwargs):
set_ltx_triton_context(active=True, **kwargs)
except ImportError:
pass
try:
from modelopt.torch.kernels.common.attention.hf_triton_attention import (
set_hf_triton_skip_softmax_config,
)

set_hf_triton_skip_softmax_config(**kwargs)
except ImportError:
pass

def _clear_triton_backends(self):
"""Clear config on both Triton backends."""
"""Clear config on the diffusers, LTX, and HF Triton backends."""
try:
from modelopt.torch.kernels.sparsity.attention.diffusers_triton_attention import (
clear_triton_skip_softmax_config,
Expand All @@ -206,6 +214,14 @@ def _clear_triton_backends(self):
clear_ltx_triton_context()
except ImportError:
pass
try:
from modelopt.torch.kernels.common.attention.hf_triton_attention import (
clear_hf_triton_skip_softmax_config,
)

clear_hf_triton_skip_softmax_config()
except ImportError:
pass

def _collect_calibration_stats(self, module):
"""Read Triton calibration counters and store as stats on the module."""
Expand Down Expand Up @@ -235,6 +251,18 @@ def _collect_calibration_stats(self, module):
except ImportError:
pass

if counters is None:
try:
from modelopt.torch.kernels.common.attention.hf_triton_attention import (
get_calibration_counters,
get_calibration_seq_k,
)

counters = get_calibration_counters()
seq_k = get_calibration_seq_k()
except ImportError:
pass

if counters is None or self._threshold_trials is None:
return

Expand Down
Loading
Loading