Skip to content

Commit 89285f1

Browse files
authored
Merge branch 'main' into refactor-hub-attn-kernels
2 parents 5d49e42 + 325a950 commit 89285f1

4 files changed

Lines changed: 117 additions & 19 deletions

File tree

docs/source/en/_toctree.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,6 @@
529529
title: Kandinsky 2.2
530530
- local: api/pipelines/kandinsky3
531531
title: Kandinsky 3
532-
- local: api/pipelines/kandinsky5
533-
title: Kandinsky 5
534532
- local: api/pipelines/kolors
535533
title: Kolors
536534
- local: api/pipelines/latent_consistency_models
@@ -656,6 +654,8 @@
656654
title: Text2Video-Zero
657655
- local: api/pipelines/wan
658656
title: Wan
657+
- local: api/pipelines/kandinsky5_video
658+
title: Kandinsky 5.0 Video
659659
title: Video
660660
title: Pipelines
661661
- sections:

docs/source/en/api/pipelines/kandinsky5.md renamed to docs/source/en/api/pipelines/kandinsky5_video.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
77
specific language governing permissions and limitations under the License.
88
-->
99

10-
# Kandinsky 5.0
10+
# Kandinsky 5.0 Video
1111

12-
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
12+
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
1313

1414

1515
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
@@ -92,7 +92,7 @@ pipe = pipe.to("cuda")
9292

9393
pipe.transformer.set_attention_backend(
9494
"flex"
95-
) # <--- Set attention backend to Flex
95+
) # <--- Sett attention bakend to Flex
9696
pipe.transformer.compile(
9797
mode="max-autotune-no-cudagraphs",
9898
dynamic=True
@@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9)
115115
```
116116

117117
### Diffusion Distilled model
118-
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
118+
**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
119119

120120
```python
121121
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"

src/diffusers/models/attention_dispatch.py

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,86 @@ def _(
675675
# ===== Helper functions to use attention backends with templated CP autograd functions =====
676676

677677

678+
def _native_attention_forward_op(
679+
ctx: torch.autograd.function.FunctionCtx,
680+
query: torch.Tensor,
681+
key: torch.Tensor,
682+
value: torch.Tensor,
683+
attn_mask: Optional[torch.Tensor] = None,
684+
dropout_p: float = 0.0,
685+
is_causal: bool = False,
686+
scale: Optional[float] = None,
687+
enable_gqa: bool = False,
688+
return_lse: bool = False,
689+
_save_ctx: bool = True,
690+
_parallel_config: Optional["ParallelConfig"] = None,
691+
):
692+
# Native attention does not return_lse
693+
if return_lse:
694+
raise ValueError("Native attention does not support return_lse=True")
695+
696+
# used for backward pass
697+
if _save_ctx:
698+
ctx.save_for_backward(query, key, value)
699+
ctx.attn_mask = attn_mask
700+
ctx.dropout_p = dropout_p
701+
ctx.is_causal = is_causal
702+
ctx.scale = scale
703+
ctx.enable_gqa = enable_gqa
704+
705+
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
706+
out = torch.nn.functional.scaled_dot_product_attention(
707+
query=query,
708+
key=key,
709+
value=value,
710+
attn_mask=attn_mask,
711+
dropout_p=dropout_p,
712+
is_causal=is_causal,
713+
scale=scale,
714+
enable_gqa=enable_gqa,
715+
)
716+
out = out.permute(0, 2, 1, 3)
717+
718+
return out
719+
720+
721+
def _native_attention_backward_op(
722+
ctx: torch.autograd.function.FunctionCtx,
723+
grad_out: torch.Tensor,
724+
*args,
725+
**kwargs,
726+
):
727+
query, key, value = ctx.saved_tensors
728+
729+
query.requires_grad_(True)
730+
key.requires_grad_(True)
731+
value.requires_grad_(True)
732+
733+
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
734+
out = torch.nn.functional.scaled_dot_product_attention(
735+
query=query_t,
736+
key=key_t,
737+
value=value_t,
738+
attn_mask=ctx.attn_mask,
739+
dropout_p=ctx.dropout_p,
740+
is_causal=ctx.is_causal,
741+
scale=ctx.scale,
742+
enable_gqa=ctx.enable_gqa,
743+
)
744+
out = out.permute(0, 2, 1, 3)
745+
746+
grad_out_t = grad_out.permute(0, 2, 1, 3)
747+
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
748+
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
749+
)
750+
751+
grad_query = grad_query_t.permute(0, 2, 1, 3)
752+
grad_key = grad_key_t.permute(0, 2, 1, 3)
753+
grad_value = grad_value_t.permute(0, 2, 1, 3)
754+
755+
return grad_query, grad_key, grad_value
756+
757+
678758
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
679759
# forward declaration:
680760
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -1550,6 +1630,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
15501630
@_AttentionBackendRegistry.register(
15511631
AttentionBackendName.NATIVE,
15521632
constraints=[_check_device, _check_shape],
1633+
supports_context_parallel=True,
15531634
)
15541635
def _native_attention(
15551636
query: torch.Tensor,
@@ -1565,18 +1646,35 @@ def _native_attention(
15651646
) -> torch.Tensor:
15661647
if return_lse:
15671648
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
1568-
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
1569-
out = torch.nn.functional.scaled_dot_product_attention(
1570-
query=query,
1571-
key=key,
1572-
value=value,
1573-
attn_mask=attn_mask,
1574-
dropout_p=dropout_p,
1575-
is_causal=is_causal,
1576-
scale=scale,
1577-
enable_gqa=enable_gqa,
1578-
)
1579-
out = out.permute(0, 2, 1, 3)
1649+
if _parallel_config is None:
1650+
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
1651+
out = torch.nn.functional.scaled_dot_product_attention(
1652+
query=query,
1653+
key=key,
1654+
value=value,
1655+
attn_mask=attn_mask,
1656+
dropout_p=dropout_p,
1657+
is_causal=is_causal,
1658+
scale=scale,
1659+
enable_gqa=enable_gqa,
1660+
)
1661+
out = out.permute(0, 2, 1, 3)
1662+
else:
1663+
out = _templated_context_parallel_attention(
1664+
query,
1665+
key,
1666+
value,
1667+
attn_mask,
1668+
dropout_p,
1669+
is_causal,
1670+
scale,
1671+
enable_gqa,
1672+
return_lse,
1673+
forward_op=_native_attention_forward_op,
1674+
backward_op=_native_attention_backward_op,
1675+
_parallel_config=_parallel_config,
1676+
)
1677+
15801678
return out
15811679

15821680

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def from_pretrained(
313313
]
314314
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
315315

316-
config = cls.load_config(pretrained_model_name_or_path)
316+
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
317317
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
318318
trust_remote_code = resolve_trust_remote_code(
319319
trust_remote_code, pretrained_model_name_or_path, has_remote_code

0 commit comments

Comments
 (0)