@@ -72,8 +72,8 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
7272 title = "Backend implementation." ,
7373 description = (
7474 "Backend to use for sparse attention computation. "
75- "Only 'pytorch' is supported, which uses softmax patching with F.softmax. "
76- "Requires model to be loaded with attn_implementation='eager' ."
75+ "'pytorch' uses softmax patching with F.softmax (requires attn_implementation='eager'). "
76+ "'triton' uses the fused Triton kernel (requires attn_implementation='modelopt_triton')."
7777 ),
7878 )
7979
@@ -89,10 +89,20 @@ class SparseAttentionAttributeConfig(ModeloptBaseConfig):
8989 description = (
9090 "Whether the model uses causal (autoregressive) attention. "
9191 "If True, sparsity statistics are calculated over the lower triangle only. "
92+ "Set to False for cross-attention models. "
9293 "Defaults to True for decoder-only models like GPT, LLaMA, etc."
9394 ),
9495 )
9596
97+ skip_diagonal_blocks : bool = ModeloptField (
98+ default = True ,
99+ title = "Skip diagonal blocks." ,
100+ description = (
101+ "When True, keep diagonal tiles dense for 2:4 sparse attention. "
102+ "Only used by sparse24_triton method. Defaults to True."
103+ ),
104+ )
105+
96106 @field_validator ("method" )
97107 @classmethod
98108 def validate_method (cls , v ):
@@ -104,11 +114,12 @@ def validate_method(cls, v):
104114 @field_validator ("backend" )
105115 @classmethod
106116 def validate_backend (cls , v ):
107- """Validate backend is pytorch."""
108- if v != "pytorch" :
117+ """Validate backend is pytorch or triton."""
118+ if v not in ( "pytorch" , "triton" ) :
109119 raise ValueError (
110- f"Invalid backend: { v } . Only 'pytorch' backend is supported. "
111- f"Model must be loaded with attn_implementation='eager'."
120+ f"Invalid backend: { v }. Supported backends: 'pytorch' (requires "
121+ f"attn_implementation='eager'), 'triton' (requires "
122+ f"attn_implementation='modelopt_triton')."
112123 )
113124 return v
114125
@@ -416,10 +427,24 @@ class FlashSkipSoftmaxConfig(SparseAttentionConfig):
416427 },
417428}
418429
430+ # 2:4 structured sparsity via Triton prefill kernel (prefill-only)
431+ SPARSE24_TRITON = {
432+ "sparse_cfg" : {
433+ "*attn*" : {
434+ "method" : "sparse24_triton" ,
435+ "backend" : "triton" ,
436+ "skip_diagonal_blocks" : True ,
437+ "enable" : True ,
438+ },
439+ "default" : {"enable" : False },
440+ },
441+ }
442+
419443
420444__all__ = [
421445 "SKIP_SOFTMAX_CALIB" ,
422446 "SKIP_SOFTMAX_DEFAULT" ,
447+ "SPARSE24_TRITON" ,
423448 "CalibrationConfig" ,
424449 "FlashSkipSoftmaxConfig" ,
425450 "SparseAttentionAttributeConfig" ,
0 commit comments