3636
3737``triton-sparse`` (requires triton + modelopt)
3838 ModelOpt Triton flash-attention kernel with N:M sparse softmax (2:4 by default).
39- Applied via ``mtsa.sparsify()`` to the WAN transformer using the ``diffusers_triton``
40- backend. For every 4 K positions, keeps top-2 attention scores; the other 2 are
41- set to -inf before softmax. Uses WanSparseAttentionModule from modelopt.
39+ Applied via ``mtsa.sparsify()`` to the WAN transformer using the ``triton``
40+ backend with the modelopt_triton diffusers attention backend. For every 4 K
41+ positions, keeps top-2 attention scores; the other 2 are set to -inf before
42+ softmax.
4243
4344``triton-skip`` (requires triton + modelopt)
4445 ModelOpt Triton flash-attention kernel with skip-softmax tile pruning.
4546 Tiles whose attention mass is below a threshold (default 0.1) are skipped entirely.
46- Applied via ``mtsa.sparsify()`` using the ``diffusers_triton`` backend.
47+ Applied via ``mtsa.sparsify()`` using the ``triton`` backend with the modelopt_triton
48+ diffusers attention backend.
4749
48- ``triton-sparse-nvfp4`` (requires triton + modelopt)
49- ModelOpt Triton flash-attention with N:M sparse softmax (2:4) AND NVFP4 E2M1
50- P-matrix quantization in a single fused Triton kernel pass. Per-tile scaling
51- (one scale per BLOCK_M×BLOCK_N tile) — finer granularity than a Python
52- post-softmax approach. Combines sparsity and quantization in one pass.
53-
54- ``triton-skip-nvfp4`` (requires triton + modelopt)
55- ModelOpt Triton flash-attention with skip-softmax tile pruning AND NVFP4 E2M1
56- P-matrix quantization in a single fused Triton kernel pass. Skipped tiles
57- contribute nothing and are never quantized.
50+ NVFP4 P-matrix quantization (``--quantize-p``) is a **SageAttention** feature —
51+ an independent quantization pass applied via
52+ ``modelopt.torch.quantization.apply_sage_attention()``. It quantizes the
53+ post-softmax P tile to NVFP4 E2M1 inside the Triton kernel (per-tile max scaling).
54+ ``--quantize-p`` can be combined with any Triton sparse kernel or used standalone.
5855
5956Requirements::
6057
8481 python wan2_sage_attention.py --prompt "..." --kernel triton-skip
8582
8683 # ModelOpt Triton sparse + NVFP4 P-matrix quantization
87- python wan2_sage_attention.py --prompt "..." --kernel triton-sparse-nvfp4
84+ python wan2_sage_attention.py --prompt "..." --kernel triton-sparse --quantize-p
8885
8986 # ModelOpt Triton skip-softmax + NVFP4 P-matrix quantization
90- python wan2_sage_attention.py --prompt "..." --kernel triton-skip-nvfp4
87+ python wan2_sage_attention.py --prompt "..." --kernel triton-skip --quantize-p
9188
9289 # Smaller 5B model (fits on a single 24 GB GPU)
9390 python wan2_sage_attention.py \\
118115KERNEL_SAGE2_FP8 = "sage2-fp8"
119116KERNEL_TRITON_SPARSE = "triton-sparse"
120117KERNEL_TRITON_SKIP = "triton-skip"
121- KERNEL_TRITON_SPARSE_NVFP4 = "triton-sparse-nvfp4"
122- KERNEL_TRITON_SKIP_NVFP4 = "triton-skip-nvfp4"
123118KERNEL_CHOICES = [
124119 KERNEL_FP8 ,
125120 KERNEL_SAGE1 ,
126121 KERNEL_SAGE2_FP16 ,
127122 KERNEL_SAGE2_FP8 ,
128123 KERNEL_TRITON_SPARSE ,
129124 KERNEL_TRITON_SKIP ,
130- KERNEL_TRITON_SPARSE_NVFP4 ,
131- KERNEL_TRITON_SKIP_NVFP4 ,
132125]
133126
134127# Kernels that modify pipe.transformer in-place via ModelOpt APIs (not SDPA patching).
135128_TRITON_MODELOPT_KERNELS = {
136129 KERNEL_TRITON_SPARSE ,
137130 KERNEL_TRITON_SKIP ,
138- KERNEL_TRITON_SPARSE_NVFP4 ,
139- KERNEL_TRITON_SKIP_NVFP4 ,
140131}
141132
142133_KERNEL_DESCRIPTIONS = {
146137 KERNEL_SAGE2_FP8 : "sageattn_qk_int8_pv_fp8_cuda (SA2++, INT8 QK + FP8 PV, fp32+fp16 accum)" ,
147138 KERNEL_TRITON_SPARSE : "ModelOpt Triton flash-attn + N:M sparse softmax (2:4) via mtsa.sparsify()" ,
148139 KERNEL_TRITON_SKIP : "ModelOpt Triton flash-attn + skip-softmax tile pruning via mtsa.sparsify()" ,
149- KERNEL_TRITON_SPARSE_NVFP4 : "ModelOpt Triton flash-attn + 2:4 sparse softmax + NVFP4 P-matrix quantization" ,
150- KERNEL_TRITON_SKIP_NVFP4 : "ModelOpt Triton flash-attn + skip-softmax tile pruning + NVFP4 P-matrix quantization" ,
151140}
152141
153142# SageAttention CUDA kernel support by GPU compute capability:
@@ -222,8 +211,6 @@ def _detect_available_kernels() -> list[str]:
222211
223212 available .append (KERNEL_TRITON_SPARSE )
224213 available .append (KERNEL_TRITON_SKIP )
225- available .append (KERNEL_TRITON_SPARSE_NVFP4 )
226- available .append (KERNEL_TRITON_SKIP_NVFP4 )
227214 except ImportError :
228215 pass
229216
@@ -443,7 +430,7 @@ def attention_kernel_ctx(kernel: str = KERNEL_FP8):
443430 "sparsity_m" : 4 ,
444431 "num_sink_tokens" : 0 ,
445432 "dense_window_size" : 0 ,
446- "backend" : "diffusers_triton " ,
433+ "backend" : "triton " ,
447434 "enable" : True ,
448435 },
449436 "default" : {"enable" : False },
@@ -457,36 +444,7 @@ def attention_kernel_ctx(kernel: str = KERNEL_FP8):
457444 "*" : {
458445 "method" : "triton_skip_softmax" ,
459446 "skip_softmax_threshold" : _TRITON_SKIP_DEFAULT_THRESHOLD ,
460- "backend" : "diffusers_triton" ,
461- "enable" : True ,
462- },
463- "default" : {"enable" : False },
464- }
465- }
466-
467- _TRITON_SPARSE_NVFP4_CONFIG = {
468- "sparse_cfg" : {
469- "*" : {
470- "method" : "triton_sparse_softmax" ,
471- "sparsity_n" : 2 ,
472- "sparsity_m" : 4 ,
473- "num_sink_tokens" : 0 ,
474- "dense_window_size" : 0 ,
475- "backend" : "diffusers_triton" ,
476- "quantize_p" : True ,
477- "enable" : True ,
478- },
479- "default" : {"enable" : False },
480- }
481- }
482-
483- _TRITON_SKIP_NVFP4_CONFIG = {
484- "sparse_cfg" : {
485- "*" : {
486- "method" : "triton_skip_softmax" ,
487- "skip_softmax_threshold" : _TRITON_SKIP_DEFAULT_THRESHOLD ,
488- "backend" : "diffusers_triton" ,
489- "quantize_p" : True ,
447+ "backend" : "triton" ,
490448 "enable" : True ,
491449 },
492450 "default" : {"enable" : False },
@@ -496,8 +454,6 @@ def attention_kernel_ctx(kernel: str = KERNEL_FP8):
496454_TRITON_KERNEL_CONFIGS = {
497455 KERNEL_TRITON_SPARSE : _TRITON_SPARSE_CONFIG ,
498456 KERNEL_TRITON_SKIP : _TRITON_SKIP_CONFIG ,
499- KERNEL_TRITON_SPARSE_NVFP4 : _TRITON_SPARSE_NVFP4_CONFIG ,
500- KERNEL_TRITON_SKIP_NVFP4 : _TRITON_SKIP_NVFP4_CONFIG ,
501457}
502458
503459
@@ -508,12 +464,12 @@ def apply_triton_sparse_kernel(
508464) -> None :
509465 """Apply a ModelOpt Triton sparse attention kernel to the WAN transformer.
510466
511- Calls ``mtsa.sparsify()`` with ``backend="diffusers_triton"``, which installs
512- a ``ModelOptWanAttnProcessor`` on every ``WanAttention`` module. The NVFP4
513- variants additionally pass ``quantize_p=True`` to the Triton kernel, enabling
514- per-tile NVFP4 E2M1 P-matrix quantization in a single fused pass.
467+ Calls ``mtsa.sparsify()`` with the ``triton`` backend, which activates the
468+ modelopt_triton diffusers attention backend for every attention forward pass.
515469
516- This modifies the model in-place.
470+ This modifies the model in-place. To additionally apply NVFP4 P-matrix
471+ quantization (SageAttention), call ``apply_sage_attention(transformer)``
472+ **after** this function.
517473
518474 Args:
519475 transformer: The ``pipe.transformer`` WAN model.
@@ -527,13 +483,13 @@ def apply_triton_sparse_kernel(
527483 import modelopt .torch .sparsity .attention_sparsity as mtsa
528484
529485 config = copy .deepcopy (_TRITON_KERNEL_CONFIGS [kernel ])
530- if skip_threshold is not None and kernel in ( KERNEL_TRITON_SKIP , KERNEL_TRITON_SKIP_NVFP4 ) :
486+ if skip_threshold is not None and kernel == KERNEL_TRITON_SKIP :
531487 config ["sparse_cfg" ]["*" ]["skip_softmax_threshold" ] = skip_threshold
532488
533489 mtsa .sparsify (transformer , config )
534490 thr = config ["sparse_cfg" ].get ("*" , {}).get ("skip_softmax_threshold" , "n/a" )
535491 print (f"[Attention] Applied { kernel } : { _KERNEL_DESCRIPTIONS [kernel ]} " )
536- if kernel in ( KERNEL_TRITON_SKIP , KERNEL_TRITON_SKIP_NVFP4 ) :
492+ if kernel == KERNEL_TRITON_SKIP :
537493 print (f"[Attention] skip_softmax_threshold={ thr } " )
538494
539495
@@ -763,9 +719,19 @@ def parse_args() -> argparse.Namespace:
763719 "sage2-fp16: SA2 INT8+FP16; "
764720 "sage2-fp8: SA2++ INT8+FP8; "
765721 "triton-sparse: ModelOpt Triton 2:4 N:M sparse softmax (requires triton + modelopt); "
766- "triton-skip: ModelOpt Triton skip-softmax tile pruning (requires triton + modelopt); "
767- "triton-sparse-nvfp4: triton-sparse + NVFP4 P-matrix quantization in one Triton pass; "
768- "triton-skip-nvfp4: triton-skip + NVFP4 P-matrix quantization in one Triton pass"
722+ "triton-skip: ModelOpt Triton skip-softmax tile pruning (requires triton + modelopt)"
723+ ),
724+ )
725+ parser .add_argument (
726+ "--quantize-p" ,
727+ action = "store_true" ,
728+ default = False ,
729+ help = (
730+ "Apply SageAttention NVFP4 E2M1 P-matrix quantization via "
731+ "modelopt.torch.quantization.apply_sage_attention(). "
732+ "Quantizes the post-softmax P tile inside the Triton kernel (per-tile max scaling). "
733+ "Can be used standalone or combined with any Triton sparse kernel: "
734+ "--kernel triton-sparse --quantize-p"
769735 ),
770736 )
771737 parser .add_argument (
@@ -792,7 +758,7 @@ def parse_args() -> argparse.Namespace:
792758 default = None ,
793759 metavar = "LAMBDA" ,
794760 help = (
795- "Override skip_softmax_threshold for triton-skip / triton-skip-nvfp4 kernels . "
761+ "Override skip_softmax_threshold for the triton-skip kernel . "
796762 f"Default: { _TRITON_SKIP_DEFAULT_THRESHOLD } . "
797763 "A tile is skipped when exp(tile_max - running_max) < LAMBDA "
798764 "(equivalently: tile_max < running_max + log(LAMBDA)). "
@@ -822,7 +788,15 @@ def main() -> None:
822788
823789 # --- Quantized ---
824790 if args .kernel in _TRITON_MODELOPT_KERNELS :
825- apply_triton_sparse_kernel (pipe .transformer , args .kernel , skip_threshold = args .skip_threshold )
791+ apply_triton_sparse_kernel (
792+ pipe .transformer ,
793+ args .kernel ,
794+ skip_threshold = args .skip_threshold ,
795+ )
796+ if args .quantize_p :
797+ from modelopt .torch .quantization import apply_sage_attention
798+
799+ apply_sage_attention (pipe .transformer )
826800 else :
827801 enable_attention_kernel (args .kernel )
828802 _ , frames_quant = run_inference (pipe , args , label = args .kernel )
@@ -896,14 +870,29 @@ def main() -> None:
896870 run_inference (pipe , args , label = "baseline" )
897871
898872 elif args .kernel in _TRITON_MODELOPT_KERNELS :
899- apply_triton_sparse_kernel (pipe .transformer , args .kernel , skip_threshold = args .skip_threshold )
873+ apply_triton_sparse_kernel (
874+ pipe .transformer ,
875+ args .kernel ,
876+ skip_threshold = args .skip_threshold ,
877+ )
878+ if args .quantize_p :
879+ from modelopt .torch .quantization import apply_sage_attention
880+
881+ apply_sage_attention (pipe .transformer )
900882 run_inference (pipe , args , label = args .kernel )
901883
902884 else :
903- enable_attention_kernel (args .kernel )
904- run_inference (pipe , args , label = args .kernel )
905- print_kernel_stats ()
906- disable_attention_kernel ()
885+ if args .quantize_p :
886+ # Standalone SageAttention (NVFP4 P-matrix) without sparse attention
887+ from modelopt .torch .quantization import apply_sage_attention
888+
889+ apply_sage_attention (pipe .transformer )
890+ run_inference (pipe , args , label = "sage_attention" )
891+ else :
892+ enable_attention_kernel (args .kernel )
893+ run_inference (pipe , args , label = args .kernel )
894+ print_kernel_stats ()
895+ disable_attention_kernel ()
907896
908897
909898if __name__ == "__main__" :
0 commit comments