@@ -71,12 +71,14 @@ def _fwd(
7171 k : torch .Tensor ,
7272 v : torch .Tensor ,
7373 causal : bool ,
74+ seqused_k : Optional [torch .Tensor ] = None ,
7475 ) -> Tuple [torch .Tensor , torch .Tensor ]:
7576 """Calls _flash_attn_fwd with torch.compile disabled. Returns (output, lse)."""
7677 output , lse = _flash_attn_fwd (
7778 q ,
7879 k ,
7980 v ,
81+ seqused_k = seqused_k ,
8082 softmax_scale = self .scale ,
8183 causal = causal ,
8284 window_size_left = None ,
@@ -120,6 +122,7 @@ def forward(
120122 v : torch .Tensor ,
121123 * ,
122124 attention_mask : PredefinedAttentionMask = PredefinedAttentionMask .FULL ,
125+ key_padding_mask : Optional [torch .Tensor ] = None ,
123126 ** kwargs ,
124127 ) -> torch .Tensor :
125128 """
@@ -132,11 +135,21 @@ def forward(
132135 k: Key tensor [batch_size, seq_len_kv, num_kv_heads, head_dim]
133136 v: Value tensor [batch_size, seq_len_kv, num_kv_heads, head_dim]
134137 attention_mask: Attention mask type (CAUSAL or FULL)
138+ key_padding_mask: Optional ``[B, S_kv]`` bool tensor; True = valid,
139+ False = pad. Converted to FA4's ``seqused_k = mask.sum(dim=1)``, so the
140+ True entries must form a contiguous prefix (not verified). Non-causal only.
135141
136142 Returns:
137143 Output tensor [batch_size, seq_len, num_heads, head_dim]
138144 """
139- output , _ = self .forward_with_lse (q , k , v , attention_mask = attention_mask , ** kwargs )
145+ output , _ = self .forward_with_lse (
146+ q ,
147+ k ,
148+ v ,
149+ attention_mask = attention_mask ,
150+ key_padding_mask = key_padding_mask ,
151+ ** kwargs ,
152+ )
140153 return output
141154
142155 def forward_with_lse (
@@ -145,6 +158,7 @@ def forward_with_lse(
145158 k : torch .Tensor ,
146159 v : torch .Tensor ,
147160 attention_mask : PredefinedAttentionMask = PredefinedAttentionMask .FULL ,
161+ key_padding_mask : Optional [torch .Tensor ] = None ,
148162 ** kwargs ,
149163 ) -> Tuple [torch .Tensor , torch .Tensor ]:
150164 """
@@ -157,7 +171,20 @@ def forward_with_lse(
157171 partial attention results in Attention2D parallelism.
158172 """
159173 q , k , v , is_causal , origin_dtype = self ._prepare_inputs (q , k , v , attention_mask )
160- output , lse = self ._fwd (q , k , v , is_causal )
174+ seqused_k = None
175+ if key_padding_mask is not None :
176+ assert not is_causal , "key_padding_mask is not supported with causal attention"
177+ assert key_padding_mask .dim () == 2 and key_padding_mask .shape == (
178+ q .shape [0 ],
179+ k .shape [1 ],
180+ ), (
181+ f"Invalid key_padding_mask shape: expected [B={ q .shape [0 ]} , "
182+ f"S_kv={ k .shape [1 ]} ], got { tuple (key_padding_mask .shape )} "
183+ )
184+ # FA4 seqused_k assumes a True-prefix layout: positions [0, valid)
185+ # are kept, [valid, S_kv) are masked. mask.sum gives the prefix length.
186+ seqused_k = key_padding_mask .sum (dim = 1 ).to (torch .int32 )
187+ output , lse = self ._fwd (q , k , v , is_causal , seqused_k = seqused_k )
161188 if output .dtype != origin_dtype :
162189 output = output .to (origin_dtype )
163190 return output , lse
0 commit comments