From d9f00fc71790d9854136f86fc92edbcbcb0b4e90 Mon Sep 17 00:00:00 2001
From: KyleAtASU
Date: Wed, 25 Mar 2026 18:30:12 -0700
Subject: [PATCH 1/3] fix: add gate_compress parameter to all attention backends (#817)

---
 fastvideo/attention/backends/abstract.py          | 6 ++++--
 fastvideo/attention/backends/flash_attn.py        | 3 ++-
 fastvideo/attention/backends/sage_attn.py         | 3 ++-
 fastvideo/attention/backends/sage_attn3.py        | 3 ++-
 fastvideo/attention/backends/sdpa.py              | 3 ++-
 fastvideo/attention/backends/sla.py               | 6 ++++--
 fastvideo/attention/backends/video_sparse_attn.py | 2 +-
 fastvideo/attention/backends/vmoba.py             | 3 ++-
 8 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/fastvideo/attention/backends/abstract.py b/fastvideo/attention/backends/abstract.py
index a244f2a82d..12e8956d58 100644
--- a/fastvideo/attention/backends/abstract.py
+++ b/fastvideo/attention/backends/abstract.py
@@ -101,7 +101,8 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
+        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         ...

@@ -168,6 +169,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_metadata: T,
+        gate_compress: torch.Tensor | None = None,
+        attn_metadata: T | None = None,
     ) -> torch.Tensor:
         raise NotImplementedError
diff --git a/fastvideo/attention/backends/flash_attn.py b/fastvideo/attention/backends/flash_attn.py
index 389a0f5aaa..564e837ef5 100644
--- a/fastvideo/attention/backends/flash_attn.py
+++ b/fastvideo/attention/backends/flash_attn.py
@@ -99,7 +99,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_metadata: FlashAttnMetadata,
+        gate_compress: torch.Tensor | None = None,
+        attn_metadata: FlashAttnMetadata | None = None,
     ):

         def _key_padding_mask_from_attn_mask(attn_mask: torch.Tensor,
                                              key_len: int) -> torch.Tensor:
diff --git a/fastvideo/attention/backends/sage_attn.py b/fastvideo/attention/backends/sage_attn.py
index 5db61f6ac0..4c4dbc5e8b 100644
--- a/fastvideo/attention/backends/sage_attn.py
+++ b/fastvideo/attention/backends/sage_attn.py
@@ -52,7 +52,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
+        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         output = sageattn(
             query,
diff --git a/fastvideo/attention/backends/sage_attn3.py b/fastvideo/attention/backends/sage_attn3.py
index b2d3379bb7..e0b2ad17f9 100644
--- a/fastvideo/attention/backends/sage_attn3.py
+++ b/fastvideo/attention/backends/sage_attn3.py
@@ -60,7 +60,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
+        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         query = query.transpose(1, 2)
         key = key.transpose(1, 2)
diff --git a/fastvideo/attention/backends/sdpa.py b/fastvideo/attention/backends/sdpa.py
index 5d9b21591d..b8cedde134 100644
--- a/fastvideo/attention/backends/sdpa.py
+++ b/fastvideo/attention/backends/sdpa.py
@@ -73,7 +73,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_metadata: SDPAMetadata,
+        gate_compress: torch.Tensor | None = None,
+        attn_metadata: SDPAMetadata | None = None,
     ) -> torch.Tensor:
         # transpose to bs, heads, seq_len, head_dim
         query = query.transpose(1, 2)
diff --git a/fastvideo/attention/backends/sla.py b/fastvideo/attention/backends/sla.py
index d3a6f37176..4ce43be8ff 100644
--- a/fastvideo/attention/backends/sla.py
+++ b/fastvideo/attention/backends/sla.py
@@ -268,7 +268,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
+        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         """Forward pass for SLA attention.

@@ -462,7 +463,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
+        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         """Forward pass for SageSLA attention with quantized kernels.

diff --git a/fastvideo/attention/backends/video_sparse_attn.py b/fastvideo/attention/backends/video_sparse_attn.py
index 43960a9401..3dfa05d1eb 100644
--- a/fastvideo/attention/backends/video_sparse_attn.py
+++ b/fastvideo/attention/backends/video_sparse_attn.py
@@ -232,7 +232,7 @@ def postprocess_output(
     ) -> torch.Tensor:
         return self.untile(output, attn_metadata.reverse_tile_partition_indices, attn_metadata.non_pad_index)

-    def forward(  # type: ignore[override]
+    def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
diff --git a/fastvideo/attention/backends/vmoba.py b/fastvideo/attention/backends/vmoba.py
index eaa618968e..64256b9a9d 100644
--- a/fastvideo/attention/backends/vmoba.py
+++ b/fastvideo/attention/backends/vmoba.py
@@ -142,7 +142,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
+        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         """
         query: [B, L, H, D]

From 62598a40e3d38ae2318cd20b670697e07c659ef5 Mon Sep 17 00:00:00 2001
From: KyleAtASU
Date: Wed, 25 Mar 2026 19:14:08 -0700
Subject: [PATCH 2/3] fix: add None checks for attn_metadata in sla.py and vmoba.py

---
 fastvideo/attention/backends/sla.py   | 6 +++---
 fastvideo/attention/backends/vmoba.py | 2 ++
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/fastvideo/attention/backends/sla.py b/fastvideo/attention/backends/sla.py
index 4ce43be8ff..5415d76d5e 100644
--- a/fastvideo/attention/backends/sla.py
+++ b/fastvideo/attention/backends/sla.py
@@ -294,8 +294,8 @@ def forward(

         # Get topk ratio from metadata if available
         topk_ratio = self.topk_ratio
-        if hasattr(attn_metadata, 'topk_ratio'):
-            topk_ratio = attn_metadata.topk_ratio  # type: ignore[union-attr]
+        if attn_metadata is not None and hasattr(attn_metadata, 'topk_ratio'):
+            topk_ratio = attn_metadata.topk_ratio

         # Compute block-sparse attention pattern
         sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=topk_ratio, BLKQ=self.BLKQ, BLKK=self.BLKK)
@@ -488,7 +488,7 @@ def forward(

         # Get topk ratio from metadata if available
         topk_ratio = self.topk_ratio
-        if hasattr(attn_metadata, 'topk_ratio'):
+        if attn_metadata is not None and hasattr(attn_metadata, 'topk_ratio'):
             topk_ratio = attn_metadata.topk_ratio  # type: ignore[union-attr]

         # Determine block sizes based on GPU architecture
diff --git a/fastvideo/attention/backends/vmoba.py b/fastvideo/attention/backends/vmoba.py
index 64256b9a9d..c653eff58a 100644
--- a/fastvideo/attention/backends/vmoba.py
+++ b/fastvideo/attention/backends/vmoba.py
@@ -151,6 +151,8 @@ def forward(
         value: [B, L, H, D]
         attn_metadata: AttentionMetadata
         """
+        if attn_metadata is None:
+            raise ValueError("VMOBAAttentionImpl requires attn_metadata to be provided.")
         batch_size, sequence_length, num_heads, head_dim = query.shape

         # select chunk type according to layer idx:

From 8c0dacf972f9c894123216efff354e368cb42917 Mon Sep 17 00:00:00 2001
From: KyleAtASU
Date: Wed, 25 Mar 2026 19:40:24 -0700
Subject: [PATCH 3/3] fix: move gate_compress after attn_metadata in all backend forward signatures

---
 fastvideo/attention/backends/abstract.py          |  5 +++--
 fastvideo/attention/backends/flash_attn.py        |  3 ++-
 fastvideo/attention/backends/sage_attn.py         |  2 +-
 fastvideo/attention/backends/sage_attn3.py        |  3 ++-
 fastvideo/attention/backends/sdpa.py              |  2 +-
 fastvideo/attention/backends/sla.py               | 10 +++++-----
 fastvideo/attention/backends/video_sparse_attn.py |  2 +-
 fastvideo/attention/backends/vmoba.py             |  5 ++---
 8 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/fastvideo/attention/backends/abstract.py b/fastvideo/attention/backends/abstract.py
index 12e8956d58..165b49ae38 100644
--- a/fastvideo/attention/backends/abstract.py
+++ b/fastvideo/attention/backends/abstract.py
@@ -101,8 +101,9 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         gate_compress: torch.Tensor | None = None,
-        attn_metadata: AttentionMetadata | None = None,
+
     ) -> torch.Tensor:
         ...

@@ -169,7 +170,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_metadata: T,
         gate_compress: torch.Tensor | None = None,
-        attn_metadata: T | None = None,
     ) -> torch.Tensor:
         raise NotImplementedError
diff --git a/fastvideo/attention/backends/flash_attn.py b/fastvideo/attention/backends/flash_attn.py
index 564e837ef5..1eb1fbef66 100644
--- a/fastvideo/attention/backends/flash_attn.py
+++ b/fastvideo/attention/backends/flash_attn.py
@@ -99,8 +99,9 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_metadata: FlashAttnMetadata,
         gate_compress: torch.Tensor | None = None,
-        attn_metadata: FlashAttnMetadata | None = None,
+
     ):

         def _key_padding_mask_from_attn_mask(attn_mask: torch.Tensor,
                                              key_len: int) -> torch.Tensor:
diff --git a/fastvideo/attention/backends/sage_attn.py b/fastvideo/attention/backends/sage_attn.py
index 4c4dbc5e8b..ed17c94693 100644
--- a/fastvideo/attention/backends/sage_attn.py
+++ b/fastvideo/attention/backends/sage_attn.py
@@ -52,8 +52,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         gate_compress: torch.Tensor | None = None,
-        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         output = sageattn(
             query,
diff --git a/fastvideo/attention/backends/sage_attn3.py b/fastvideo/attention/backends/sage_attn3.py
index e0b2ad17f9..4b2a3df926 100644
--- a/fastvideo/attention/backends/sage_attn3.py
+++ b/fastvideo/attention/backends/sage_attn3.py
@@ -60,8 +60,9 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         gate_compress: torch.Tensor | None = None,
-        attn_metadata: AttentionMetadata | None = None,
+
     ) -> torch.Tensor:
         query = query.transpose(1, 2)
         key = key.transpose(1, 2)
diff --git a/fastvideo/attention/backends/sdpa.py b/fastvideo/attention/backends/sdpa.py
index b8cedde134..da2d29fad4 100644
--- a/fastvideo/attention/backends/sdpa.py
+++ b/fastvideo/attention/backends/sdpa.py
@@ -73,8 +73,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_metadata: SDPAMetadata,
         gate_compress: torch.Tensor | None = None,
-        attn_metadata: SDPAMetadata | None = None,
     ) -> torch.Tensor:
         # transpose to bs, heads, seq_len, head_dim
         query = query.transpose(1, 2)
diff --git a/fastvideo/attention/backends/sla.py b/fastvideo/attention/backends/sla.py
index 5415d76d5e..f6b87a5628 100644
--- a/fastvideo/attention/backends/sla.py
+++ b/fastvideo/attention/backends/sla.py
@@ -268,8 +268,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         gate_compress: torch.Tensor | None = None,
-        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         """Forward pass for SLA attention.

@@ -294,8 +294,8 @@ def forward(

         # Get topk ratio from metadata if available
         topk_ratio = self.topk_ratio
-        if attn_metadata is not None and hasattr(attn_metadata, 'topk_ratio'):
-            topk_ratio = attn_metadata.topk_ratio
+        if hasattr(attn_metadata, 'topk_ratio'):
+            topk_ratio = attn_metadata.topk_ratio  # type: ignore[union-attr]

         # Compute block-sparse attention pattern
         sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=topk_ratio, BLKQ=self.BLKQ, BLKK=self.BLKK)
@@ -463,8 +463,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         gate_compress: torch.Tensor | None = None,
-        attn_metadata: AttentionMetadata | None = None,
     ) -> torch.Tensor:
         """Forward pass for SageSLA attention with quantized kernels.

@@ -488,7 +488,7 @@ def forward(

         # Get topk ratio from metadata if available
         topk_ratio = self.topk_ratio
-        if attn_metadata is not None and hasattr(attn_metadata, 'topk_ratio'):
+        if hasattr(attn_metadata, 'topk_ratio'):
             topk_ratio = attn_metadata.topk_ratio  # type: ignore[union-attr]

         # Determine block sizes based on GPU architecture
diff --git a/fastvideo/attention/backends/video_sparse_attn.py b/fastvideo/attention/backends/video_sparse_attn.py
index 3dfa05d1eb..247ef3b480 100644
--- a/fastvideo/attention/backends/video_sparse_attn.py
+++ b/fastvideo/attention/backends/video_sparse_attn.py
@@ -237,8 +237,8 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        gate_compress: torch.Tensor,
         attn_metadata: VideoSparseAttentionMetadata,
+        gate_compress: torch.Tensor,
     ) -> torch.Tensor:
         query = query.transpose(1, 2).contiguous()
         key = key.transpose(1, 2).contiguous()
diff --git a/fastvideo/attention/backends/vmoba.py b/fastvideo/attention/backends/vmoba.py
index c653eff58a..b49eb66fa5 100644
--- a/fastvideo/attention/backends/vmoba.py
+++ b/fastvideo/attention/backends/vmoba.py
@@ -142,8 +142,9 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         gate_compress: torch.Tensor | None = None,
-        attn_metadata: AttentionMetadata | None = None,
+
     ) -> torch.Tensor:
         """
         query: [B, L, H, D]
@@ -151,8 +152,6 @@ def forward(
         value: [B, L, H, D]
         attn_metadata: AttentionMetadata
         """
-        if attn_metadata is None:
-            raise ValueError("VMOBAAttentionImpl requires attn_metadata to be provided.")
         batch_size, sequence_length, num_heads, head_dim = query.shape

         # select chunk type according to layer idx:
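For readers tracking the signature churn across the three patches: the net effect of PATCH 3/3 is that every backend's `forward` takes `attn_metadata` as a required positional parameter right after `value`, with `gate_compress` following as an optional keyword (except `video_sparse_attn.py`, where `gate_compress` stays required but moves after the metadata). The sketch below is a minimal illustration of that final ordering; `ToyMetadata` and `ToyAttentionImpl` are hypothetical stand-ins written for this note, not FastVideo's real classes, and the SDPA body is only a placeholder for whatever each backend actually does.

```python
# Hypothetical sketch of the post-PATCH-3/3 parameter ordering.
# ToyMetadata / ToyAttentionImpl are illustrative stand-ins, not FastVideo classes.
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class ToyMetadata:
    current_timestep: int = 0


class ToyAttentionImpl:
    def forward(
        self,
        query: torch.Tensor,          # [B, L, H, D]
        key: torch.Tensor,            # [B, L, H, D]
        value: torch.Tensor,          # [B, L, H, D]
        attn_metadata: ToyMetadata,   # required, positional, right after value
        gate_compress: torch.Tensor | None = None,  # optional, keyword-friendly
    ) -> torch.Tensor:
        # Plain SDPA path; a gated-compression backend would consume gate_compress here.
        q, k, v = (t.transpose(1, 2) for t in (query, key, value))  # -> [B, H, L, D]
        out = F.scaled_dot_product_attention(q, k, v)
        return out.transpose(1, 2)    # back to [B, L, H, D]


# Existing call sites can keep passing the metadata positionally and only name
# the new argument when they have something to pass for it:
impl = ToyAttentionImpl()
q = k = v = torch.randn(1, 16, 4, 8)
out = impl.forward(q, k, v, ToyMetadata(), gate_compress=None)
```

Placing `gate_compress` after `attn_metadata` keeps callers that already pass the metadata as the fourth positional argument working unchanged, which appears to be what the third patch restores after the first patch reordered the parameters.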