
Commit 68f63b6

yeyu-nvidia and claude committed
Build SageAttention as standalone quantization feature
Move NVFP4 P-matrix quantization (quantize_p) out of the sparsity module and into a new modelopt/torch/quantization/sage_attention/ module. Key changes:

- Add modelopt/torch/quantization/sage_attention/__init__.py with apply_sage_attention(transformer) API exposed via the mtq namespace. Wraps the transformer forward to activate the modelopt_triton diffusers backend and set quantize_p=True in thread-local state for every call.
- Remove quantize_p from SparseAttentionAttributeConfig (config.py), TritonSkipSoftmaxMethod, and TritonSparseSoftmaxMethod — sparsity methods no longer control quantization.
- Split thread-local management in diffusers_triton_attention.py:
  * set_triton_skip_softmax_config() no longer accepts quantize_p
  * clear_triton_skip_softmax_config() does NOT reset quantize_p
  * New set_sage_attention_config() / clear_sage_attention_config() manage quantize_p independently

  This enables transparent composition: apply_sage_attention() sets quantize_p=True at the outer forward level; per-layer sparsity contexts clear only their own params without clobbering quantize_p.
- Delete plugins/diffusers.py (WanSparseAttentionModule) — superseded by PR #1166's diffusers_triton_attention.py backend approach.
- Update the wan2_sage_attention.py example: apply_triton_sparse_kernel() no longer accepts quantize_p; --quantize-p now calls apply_sage_attention() from modelopt.torch.quantization.
- Update tests to reflect the new API boundaries.

Usage:

    from modelopt.torch.quantization import apply_sage_attention

    apply_sage_attention(pipe.transformer)  # standalone

    # combined with N:M sparse softmax:
    mtsa.sparsify(transformer, mtsa.SPARSE_SOFTMAX_DEFAULT)
    apply_sage_attention(transformer)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 3f0bfd3 commit 68f63b6

File tree

9 files changed: +475 / -701 lines


examples/diffusers/quantization/wan2_sage_attention.py

Lines changed: 65 additions & 76 deletions
@@ -36,25 +36,22 @@
 
 ``triton-sparse`` (requires triton + modelopt)
     ModelOpt Triton flash-attention kernel with N:M sparse softmax (2:4 by default).
-    Applied via ``mtsa.sparsify()`` to the WAN transformer using the ``diffusers_triton``
-    backend. For every 4 K positions, keeps top-2 attention scores; the other 2 are
-    set to -inf before softmax. Uses WanSparseAttentionModule from modelopt.
+    Applied via ``mtsa.sparsify()`` to the WAN transformer using the ``triton``
+    backend with the modelopt_triton diffusers attention backend. For every 4 K
+    positions, keeps top-2 attention scores; the other 2 are set to -inf before
+    softmax.
 
 ``triton-skip`` (requires triton + modelopt)
     ModelOpt Triton flash-attention kernel with skip-softmax tile pruning.
     Tiles whose attention mass is below a threshold (default 0.1) are skipped entirely.
-    Applied via ``mtsa.sparsify()`` using the ``diffusers_triton`` backend.
+    Applied via ``mtsa.sparsify()`` using the ``triton`` backend with the modelopt_triton
+    diffusers attention backend.
 
-``triton-sparse-nvfp4`` (requires triton + modelopt)
-    ModelOpt Triton flash-attention with N:M sparse softmax (2:4) AND NVFP4 E2M1
-    P-matrix quantization in a single fused Triton kernel pass. Per-tile scaling
-    (one scale per BLOCK_M×BLOCK_N tile) — finer granularity than a Python
-    post-softmax approach. Combines sparsity and quantization in one pass.
-
-``triton-skip-nvfp4`` (requires triton + modelopt)
-    ModelOpt Triton flash-attention with skip-softmax tile pruning AND NVFP4 E2M1
-    P-matrix quantization in a single fused Triton kernel pass. Skipped tiles
-    contribute nothing and are never quantized.
+NVFP4 P-matrix quantization (``--quantize-p``) is a **SageAttention** feature —
+an independent quantization pass applied via
+``modelopt.torch.quantization.apply_sage_attention()``. It quantizes the
+post-softmax P tile to NVFP4 E2M1 inside the Triton kernel (per-tile max scaling).
+``--quantize-p`` can be combined with any Triton sparse kernel or used standalone.
 
 Requirements::
 
@@ -84,10 +81,10 @@
     python wan2_sage_attention.py --prompt "..." --kernel triton-skip
 
     # ModelOpt Triton sparse + NVFP4 P-matrix quantization
-    python wan2_sage_attention.py --prompt "..." --kernel triton-sparse-nvfp4
+    python wan2_sage_attention.py --prompt "..." --kernel triton-sparse --quantize-p
 
     # ModelOpt Triton skip-softmax + NVFP4 P-matrix quantization
-    python wan2_sage_attention.py --prompt "..." --kernel triton-skip-nvfp4
+    python wan2_sage_attention.py --prompt "..." --kernel triton-skip --quantize-p
 
     # Smaller 5B model (fits on a single 24 GB GPU)
     python wan2_sage_attention.py \\
@@ -118,25 +115,19 @@
 KERNEL_SAGE2_FP8 = "sage2-fp8"
 KERNEL_TRITON_SPARSE = "triton-sparse"
 KERNEL_TRITON_SKIP = "triton-skip"
-KERNEL_TRITON_SPARSE_NVFP4 = "triton-sparse-nvfp4"
-KERNEL_TRITON_SKIP_NVFP4 = "triton-skip-nvfp4"
 KERNEL_CHOICES = [
     KERNEL_FP8,
     KERNEL_SAGE1,
     KERNEL_SAGE2_FP16,
     KERNEL_SAGE2_FP8,
     KERNEL_TRITON_SPARSE,
     KERNEL_TRITON_SKIP,
-    KERNEL_TRITON_SPARSE_NVFP4,
-    KERNEL_TRITON_SKIP_NVFP4,
 ]
 
 # Kernels that modify pipe.transformer in-place via ModelOpt APIs (not SDPA patching).
 _TRITON_MODELOPT_KERNELS = {
     KERNEL_TRITON_SPARSE,
     KERNEL_TRITON_SKIP,
-    KERNEL_TRITON_SPARSE_NVFP4,
-    KERNEL_TRITON_SKIP_NVFP4,
 }
 
 _KERNEL_DESCRIPTIONS = {
@@ -146,8 +137,6 @@
     KERNEL_SAGE2_FP8: "sageattn_qk_int8_pv_fp8_cuda (SA2++, INT8 QK + FP8 PV, fp32+fp16 accum)",
     KERNEL_TRITON_SPARSE: "ModelOpt Triton flash-attn + N:M sparse softmax (2:4) via mtsa.sparsify()",
     KERNEL_TRITON_SKIP: "ModelOpt Triton flash-attn + skip-softmax tile pruning via mtsa.sparsify()",
-    KERNEL_TRITON_SPARSE_NVFP4: "ModelOpt Triton flash-attn + 2:4 sparse softmax + NVFP4 P-matrix quantization",
-    KERNEL_TRITON_SKIP_NVFP4: "ModelOpt Triton flash-attn + skip-softmax tile pruning + NVFP4 P-matrix quantization",
 }
 
 # SageAttention CUDA kernel support by GPU compute capability:
@@ -222,8 +211,6 @@ def _detect_available_kernels() -> list[str]:
 
         available.append(KERNEL_TRITON_SPARSE)
         available.append(KERNEL_TRITON_SKIP)
-        available.append(KERNEL_TRITON_SPARSE_NVFP4)
-        available.append(KERNEL_TRITON_SKIP_NVFP4)
     except ImportError:
         pass
 
@@ -443,7 +430,7 @@ def attention_kernel_ctx(kernel: str = KERNEL_FP8):
             "sparsity_m": 4,
             "num_sink_tokens": 0,
             "dense_window_size": 0,
-            "backend": "diffusers_triton",
+            "backend": "triton",
             "enable": True,
         },
         "default": {"enable": False},
@@ -457,36 +444,7 @@ def attention_kernel_ctx(kernel: str = KERNEL_FP8):
         "*": {
             "method": "triton_skip_softmax",
             "skip_softmax_threshold": _TRITON_SKIP_DEFAULT_THRESHOLD,
-            "backend": "diffusers_triton",
-            "enable": True,
-        },
-        "default": {"enable": False},
-    }
-}
-
-_TRITON_SPARSE_NVFP4_CONFIG = {
-    "sparse_cfg": {
-        "*": {
-            "method": "triton_sparse_softmax",
-            "sparsity_n": 2,
-            "sparsity_m": 4,
-            "num_sink_tokens": 0,
-            "dense_window_size": 0,
-            "backend": "diffusers_triton",
-            "quantize_p": True,
-            "enable": True,
-        },
-        "default": {"enable": False},
-    }
-}
-
-_TRITON_SKIP_NVFP4_CONFIG = {
-    "sparse_cfg": {
-        "*": {
-            "method": "triton_skip_softmax",
-            "skip_softmax_threshold": _TRITON_SKIP_DEFAULT_THRESHOLD,
-            "backend": "diffusers_triton",
-            "quantize_p": True,
+            "backend": "triton",
             "enable": True,
         },
         "default": {"enable": False},
@@ -496,8 +454,6 @@ def attention_kernel_ctx(kernel: str = KERNEL_FP8):
 _TRITON_KERNEL_CONFIGS = {
     KERNEL_TRITON_SPARSE: _TRITON_SPARSE_CONFIG,
     KERNEL_TRITON_SKIP: _TRITON_SKIP_CONFIG,
-    KERNEL_TRITON_SPARSE_NVFP4: _TRITON_SPARSE_NVFP4_CONFIG,
-    KERNEL_TRITON_SKIP_NVFP4: _TRITON_SKIP_NVFP4_CONFIG,
 }
 
 
@@ -508,12 +464,12 @@ def apply_triton_sparse_kernel(
 ) -> None:
     """Apply a ModelOpt Triton sparse attention kernel to the WAN transformer.
 
-    Calls ``mtsa.sparsify()`` with ``backend="diffusers_triton"``, which installs
-    a ``ModelOptWanAttnProcessor`` on every ``WanAttention`` module. The NVFP4
-    variants additionally pass ``quantize_p=True`` to the Triton kernel, enabling
-    per-tile NVFP4 E2M1 P-matrix quantization in a single fused pass.
+    Calls ``mtsa.sparsify()`` with the ``triton`` backend, which activates the
+    modelopt_triton diffusers attention backend for every attention forward pass.
 
-    This modifies the model in-place.
+    This modifies the model in-place. To additionally apply NVFP4 P-matrix
+    quantization (SageAttention), call ``apply_sage_attention(transformer)``
+    **after** this function.
 
     Args:
         transformer: The ``pipe.transformer`` WAN model.
@@ -527,13 +483,13 @@ def apply_triton_sparse_kernel(
     import modelopt.torch.sparsity.attention_sparsity as mtsa
 
     config = copy.deepcopy(_TRITON_KERNEL_CONFIGS[kernel])
-    if skip_threshold is not None and kernel in (KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4):
+    if skip_threshold is not None and kernel == KERNEL_TRITON_SKIP:
         config["sparse_cfg"]["*"]["skip_softmax_threshold"] = skip_threshold
 
     mtsa.sparsify(transformer, config)
     thr = config["sparse_cfg"].get("*", {}).get("skip_softmax_threshold", "n/a")
     print(f"[Attention] Applied {kernel}: {_KERNEL_DESCRIPTIONS[kernel]}")
-    if kernel in (KERNEL_TRITON_SKIP, KERNEL_TRITON_SKIP_NVFP4):
+    if kernel == KERNEL_TRITON_SKIP:
         print(f"[Attention] skip_softmax_threshold={thr}")
 
 
@@ -763,9 +719,19 @@ def parse_args() -> argparse.Namespace:
             "sage2-fp16: SA2 INT8+FP16; "
             "sage2-fp8: SA2++ INT8+FP8; "
             "triton-sparse: ModelOpt Triton 2:4 N:M sparse softmax (requires triton + modelopt); "
-            "triton-skip: ModelOpt Triton skip-softmax tile pruning (requires triton + modelopt); "
-            "triton-sparse-nvfp4: triton-sparse + NVFP4 P-matrix quantization in one Triton pass; "
-            "triton-skip-nvfp4: triton-skip + NVFP4 P-matrix quantization in one Triton pass"
+            "triton-skip: ModelOpt Triton skip-softmax tile pruning (requires triton + modelopt)"
+        ),
+    )
+    parser.add_argument(
+        "--quantize-p",
+        action="store_true",
+        default=False,
+        help=(
+            "Apply SageAttention NVFP4 E2M1 P-matrix quantization via "
+            "modelopt.torch.quantization.apply_sage_attention(). "
+            "Quantizes the post-softmax P tile inside the Triton kernel (per-tile max scaling). "
+            "Can be used standalone or combined with any Triton sparse kernel: "
+            "--kernel triton-sparse --quantize-p"
         ),
     )
    parser.add_argument(
@@ -792,7 +758,7 @@ def parse_args() -> argparse.Namespace:
         default=None,
         metavar="LAMBDA",
         help=(
-            "Override skip_softmax_threshold for triton-skip / triton-skip-nvfp4 kernels. "
+            "Override skip_softmax_threshold for the triton-skip kernel. "
             f"Default: {_TRITON_SKIP_DEFAULT_THRESHOLD}. "
             "A tile is skipped when exp(tile_max - running_max) < LAMBDA "
             "(equivalently: tile_max < running_max + log(LAMBDA)). "
@@ -822,7 +788,15 @@ def main() -> None:
 
     # --- Quantized ---
     if args.kernel in _TRITON_MODELOPT_KERNELS:
-        apply_triton_sparse_kernel(pipe.transformer, args.kernel, skip_threshold=args.skip_threshold)
+        apply_triton_sparse_kernel(
+            pipe.transformer,
+            args.kernel,
+            skip_threshold=args.skip_threshold,
+        )
+        if args.quantize_p:
+            from modelopt.torch.quantization import apply_sage_attention
+
+            apply_sage_attention(pipe.transformer)
     else:
         enable_attention_kernel(args.kernel)
     _, frames_quant = run_inference(pipe, args, label=args.kernel)
@@ -896,14 +870,29 @@ def main() -> None:
         run_inference(pipe, args, label="baseline")
 
     elif args.kernel in _TRITON_MODELOPT_KERNELS:
-        apply_triton_sparse_kernel(pipe.transformer, args.kernel, skip_threshold=args.skip_threshold)
+        apply_triton_sparse_kernel(
+            pipe.transformer,
+            args.kernel,
+            skip_threshold=args.skip_threshold,
+        )
+        if args.quantize_p:
+            from modelopt.torch.quantization import apply_sage_attention
+
+            apply_sage_attention(pipe.transformer)
         run_inference(pipe, args, label=args.kernel)
 
     else:
-        enable_attention_kernel(args.kernel)
-        run_inference(pipe, args, label=args.kernel)
-        print_kernel_stats()
-        disable_attention_kernel()
+        if args.quantize_p:
+            # Standalone SageAttention (NVFP4 P-matrix) without sparse attention
+            from modelopt.torch.quantization import apply_sage_attention
+
+            apply_sage_attention(pipe.transformer)
+            run_inference(pipe, args, label="sage_attention")
+        else:
+            enable_attention_kernel(args.kernel)
+            run_inference(pipe, args, label=args.kernel)
+            print_kernel_stats()
+            disable_attention_kernel()
 
 
 if __name__ == "__main__":

modelopt/torch/quantization/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,4 +24,5 @@
 from .conversion import *
 from .model_quant import *
 from .nn.modules.quant_module import QuantModuleRegistry
+from .sage_attention import apply_sage_attention
 from .utils import update_quant_cfg_with_kv_cache_quant
