@@ -281,6 +281,19 @@ class Kandinsky5AttnProcessor:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+    @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
+    def compiled_flex_attn(self, query, key, value, attn_mask, backend, parallel_config):
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask,
+            backend=backend,
+            parallel_config=parallel_config,
+        )
+
+        return hidden_states

     def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None):
         # query, key, value = self.get_qkv(x)
@@ -324,17 +337,28 @@ def apply_rotary(x, rope):
                 sparse_params["sta_mask"],
                 thr=sparse_params["P"],
             )
+
+            hidden_states = self.compiled_flex_attn(
+                query,
+                key,
+                value,
+                attn_mask=attn_mask,
+                backend=self._attention_backend,
+                parallel_config=self._parallel_config
+            )
+
         else:
             attn_mask = None
+
+            hidden_states = dispatch_attention_fn(
+                query,
+                key,
+                value,
+                attn_mask=attn_mask,
+                backend=self._attention_backend,
+                parallel_config=self._parallel_config,
+            )

-        hidden_states = dispatch_attention_fn(
-            query,
-            key,
-            value,
-            attn_mask=attn_mask,
-            backend=self._attention_backend,
-            parallel_config=self._parallel_config,
-        )
         hidden_states = hidden_states.flatten(-2, -1)

         attn_out = attn.out_layer(hidden_states)
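For context, a minimal standalone sketch of the pattern this diff introduces: moving the masked attention call into its own torch.compile-decorated method so it is traced once and reused across denoising steps. This is not the diffusers implementation; F.scaled_dot_product_attention stands in for the internal dispatch_attention_fn, and the SimpleProcessor class, tensor shapes, and mask are hypothetical.

# Hypothetical sketch, not the diffusers implementation: a torch.compile-wrapped
# attention method mirroring the compiled_flex_attn pattern above.
# F.scaled_dot_product_attention stands in for dispatch_attention_fn here.
import torch
import torch.nn.functional as F

class SimpleProcessor:
    @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
    def compiled_attn(self, query, key, value, attn_mask=None):
        # dynamic=True keeps one compiled graph across varying sequence lengths,
        # avoiding recompiles when the token count changes between calls.
        return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

proc = SimpleProcessor()
q = k = v = torch.randn(1, 8, 128, 64)  # (batch, heads, seq, head_dim); hypothetical shapes
mask = torch.ones(1, 8, 128, 128, dtype=torch.bool)  # boolean keep-mask; hypothetical
out = proc.compiled_attn(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 8, 128, 64])

Keeping only the masked branch in a compiled method, rather than compiling all of __call__, confines compilation to the attention kernel itself, while the unmasked else branch keeps the eager dispatch path, matching the split in the diff.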