diff --git a/fastvideo/attention/backends/abstract.py b/fastvideo/attention/backends/abstract.py
index a244f2a82d..165b49ae38 100644
--- a/fastvideo/attention/backends/abstract.py
+++ b/fastvideo/attention/backends/abstract.py
@@ -102,6 +102,8 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
+
     ) -> torch.Tensor:
         ...

@@ -169,5 +171,6 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: T,
+        gate_compress: torch.Tensor | None = None,
     ) -> torch.Tensor:
         raise NotImplementedError
diff --git a/fastvideo/attention/backends/flash_attn.py b/fastvideo/attention/backends/flash_attn.py
index 389a0f5aaa..1eb1fbef66 100644
--- a/fastvideo/attention/backends/flash_attn.py
+++ b/fastvideo/attention/backends/flash_attn.py
@@ -100,6 +100,8 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: FlashAttnMetadata,
+        gate_compress: torch.Tensor | None = None,
+
     ):
         def _key_padding_mask_from_attn_mask(attn_mask: torch.Tensor,
                                              key_len: int) -> torch.Tensor:
diff --git a/fastvideo/attention/backends/sage_attn.py b/fastvideo/attention/backends/sage_attn.py
index 5db61f6ac0..ed17c94693 100644
--- a/fastvideo/attention/backends/sage_attn.py
+++ b/fastvideo/attention/backends/sage_attn.py
@@ -53,6 +53,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
     ) -> torch.Tensor:
         output = sageattn(
             query,
diff --git a/fastvideo/attention/backends/sage_attn3.py b/fastvideo/attention/backends/sage_attn3.py
index b2d3379bb7..4b2a3df926 100644
--- a/fastvideo/attention/backends/sage_attn3.py
+++ b/fastvideo/attention/backends/sage_attn3.py
@@ -61,6 +61,8 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
+
     ) -> torch.Tensor:
         query = query.transpose(1, 2)
         key = key.transpose(1, 2)
diff --git a/fastvideo/attention/backends/sdpa.py b/fastvideo/attention/backends/sdpa.py
index 5d9b21591d..da2d29fad4 100644
--- a/fastvideo/attention/backends/sdpa.py
+++ b/fastvideo/attention/backends/sdpa.py
@@ -74,6 +74,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: SDPAMetadata,
+        gate_compress: torch.Tensor | None = None,
     ) -> torch.Tensor:
         # transpose to bs, heads, seq_len, head_dim
         query = query.transpose(1, 2)
diff --git a/fastvideo/attention/backends/sla.py b/fastvideo/attention/backends/sla.py
index d3a6f37176..f6b87a5628 100644
--- a/fastvideo/attention/backends/sla.py
+++ b/fastvideo/attention/backends/sla.py
@@ -269,6 +269,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Forward pass for SLA attention.

@@ -463,6 +464,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Forward pass for SageSLA attention with quantized kernels.

diff --git a/fastvideo/attention/backends/video_sparse_attn.py b/fastvideo/attention/backends/video_sparse_attn.py
index 43960a9401..247ef3b480 100644
--- a/fastvideo/attention/backends/video_sparse_attn.py
+++ b/fastvideo/attention/backends/video_sparse_attn.py
@@ -232,13 +232,13 @@ def postprocess_output(
     ) -> torch.Tensor:
         return self.untile(output, attn_metadata.reverse_tile_partition_indices, attn_metadata.non_pad_index)

-    def forward(  # type: ignore[override]
+    def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        gate_compress: torch.Tensor,
         attn_metadata: VideoSparseAttentionMetadata,
+        gate_compress: torch.Tensor,
     ) -> torch.Tensor:
         query = query.transpose(1, 2).contiguous()
         key = key.transpose(1, 2).contiguous()
diff --git a/fastvideo/attention/backends/vmoba.py b/fastvideo/attention/backends/vmoba.py
index eaa618968e..b49eb66fa5 100644
--- a/fastvideo/attention/backends/vmoba.py
+++ b/fastvideo/attention/backends/vmoba.py
@@ -143,6 +143,8 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        gate_compress: torch.Tensor | None = None,
+
     ) -> torch.Tensor:
         """
         query: [B, L, H, D]
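Taken together, these hunks unify the backend interface: the base forward signatures in abstract.py gain gate_compress as a trailing optional parameter defaulting to None, every backend accepts it, and video_sparse_attn (the only backend that actually consumes it) moves gate_compress after attn_metadata so its signature lines up with the base class and the "# type: ignore[override]" suppression can be dropped. A minimal sketch of a conforming backend follows; the class name and SDPA-based body are illustrative assumptions, not part of the patch:

import torch


class ExampleBackendImpl:  # hypothetical name, for illustration only

    def forward(
        self,
        query: torch.Tensor,  # [B, L, H, D], as documented in vmoba.py
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: "AttentionMetadata",  # forward reference to the type in abstract.py
        gate_compress: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Backends without gated compression accept and ignore gate_compress,
        # mirroring what flash_attn/sdpa/sage_attn do after this change.
        del gate_compress
        # Transpose to [B, H, L, D] for attention, then back, following the
        # transpose(1, 2) pattern visible in the sdpa.py and sage_attn3.py hunks.
        q, k, v = (t.transpose(1, 2) for t in (query, key, value))
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return out.transpose(1, 2)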