Skip to content

Commit b35445c

Browse files
committed
add compiled Nabla Attention
1 parent 11200b4 commit b35445c

1 file changed

Lines changed: 32 additions & 8 deletions

File tree

src/diffusers/models/transformers/transformer_kandinsky.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,19 @@ class Kandinsky5AttnProcessor:
281281
def __init__(self):
    # Fail fast at construction time: this processor depends on the fused
    # SDPA kernel introduced in PyTorch 2.0 (F.scaled_dot_product_attention).
    if not hasattr(F, "scaled_dot_product_attention"):
        raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
284+
285+
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def compiled_flex_attn(self, query, key, value, attn_mask, backend, parallel_config):
    """Dispatch attention through a torch.compile-d wrapper.

    Thin pass-through to ``dispatch_attention_fn`` so the sparse (Nabla)
    attention path runs under ``torch.compile`` (max-autotune, no CUDA
    graphs, dynamic shapes). Returns the attention output unchanged.
    """
    # No intermediate binding needed — forward the call result directly.
    return dispatch_attention_fn(
        query,
        key,
        value,
        attn_mask=attn_mask,
        backend=backend,
        parallel_config=parallel_config,
    )
284297

285298
def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None):
286299
# query, key, value = self.get_qkv(x)
@@ -324,17 +337,28 @@ def apply_rotary(x, rope):
324337
sparse_params["sta_mask"],
325338
thr=sparse_params["P"],
326339
)
340+
341+
hidden_states = self.compiled_flex_attn(
342+
query,
343+
key,
344+
value,
345+
attn_mask=attn_mask,
346+
backend=self._attention_backend,
347+
parallel_config=self._parallel_config
348+
)
349+
327350
else:
328351
attn_mask = None
352+
353+
hidden_states = dispatch_attention_fn(
354+
query,
355+
key,
356+
value,
357+
attn_mask=attn_mask,
358+
backend=self._attention_backend,
359+
parallel_config=self._parallel_config,
360+
)
329361

330-
hidden_states = dispatch_attention_fn(
331-
query,
332-
key,
333-
value,
334-
attn_mask=attn_mask,
335-
backend=self._attention_backend,
336-
parallel_config=self._parallel_config,
337-
)
338362
hidden_states = hidden_states.flatten(-2, -1)
339363

340364
attn_out = attn.out_layer(hidden_states)

0 commit comments

Comments
 (0)