Skip to content

Commit 3dcc9ca

Browse files
authored
Merge branch 'main' into cp-fix
2 parents d65f857 + 325a950 commit 3dcc9ca

21 files changed

Lines changed: 291 additions & 61 deletions

docs/source/en/_toctree.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,6 @@
529529
title: Kandinsky 2.2
530530
- local: api/pipelines/kandinsky3
531531
title: Kandinsky 3
532-
- local: api/pipelines/kandinsky5
533-
title: Kandinsky 5
534532
- local: api/pipelines/kolors
535533
title: Kolors
536534
- local: api/pipelines/latent_consistency_models
@@ -656,6 +654,8 @@
656654
title: Text2Video-Zero
657655
- local: api/pipelines/wan
658656
title: Wan
657+
- local: api/pipelines/kandinsky5_video
658+
title: Kandinsky 5.0 Video
659659
title: Video
660660
title: Pipelines
661661
- sections:

docs/source/en/api/pipelines/kandinsky5.md renamed to docs/source/en/api/pipelines/kandinsky5_video.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
77
specific language governing permissions and limitations under the License.
88
-->
99

10-
# Kandinsky 5.0
10+
# Kandinsky 5.0 Video
1111

12-
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
12+
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
1313

1414

1515
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
@@ -92,7 +92,7 @@ pipe = pipe.to("cuda")
9292

9393
pipe.transformer.set_attention_backend(
9494
"flex"
95-
) # <--- Set attention backend to Flex
95+
) # <--- Sett attention bakend to Flex
9696
pipe.transformer.compile(
9797
mode="max-autotune-no-cudagraphs",
9898
dynamic=True
@@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9)
115115
```
116116

117117
### Diffusion Distilled model
118-
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
118+
**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
119119

120120
```python
121121
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"

src/diffusers/models/attention_dispatch.py

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,86 @@ def _(
640640
# ===== Helper functions to use attention backends with templated CP autograd functions =====
641641

642642

643+
def _native_attention_forward_op(
644+
ctx: torch.autograd.function.FunctionCtx,
645+
query: torch.Tensor,
646+
key: torch.Tensor,
647+
value: torch.Tensor,
648+
attn_mask: Optional[torch.Tensor] = None,
649+
dropout_p: float = 0.0,
650+
is_causal: bool = False,
651+
scale: Optional[float] = None,
652+
enable_gqa: bool = False,
653+
return_lse: bool = False,
654+
_save_ctx: bool = True,
655+
_parallel_config: Optional["ParallelConfig"] = None,
656+
):
657+
# Native attention does not return_lse
658+
if return_lse:
659+
raise ValueError("Native attention does not support return_lse=True")
660+
661+
# used for backward pass
662+
if _save_ctx:
663+
ctx.save_for_backward(query, key, value)
664+
ctx.attn_mask = attn_mask
665+
ctx.dropout_p = dropout_p
666+
ctx.is_causal = is_causal
667+
ctx.scale = scale
668+
ctx.enable_gqa = enable_gqa
669+
670+
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
671+
out = torch.nn.functional.scaled_dot_product_attention(
672+
query=query,
673+
key=key,
674+
value=value,
675+
attn_mask=attn_mask,
676+
dropout_p=dropout_p,
677+
is_causal=is_causal,
678+
scale=scale,
679+
enable_gqa=enable_gqa,
680+
)
681+
out = out.permute(0, 2, 1, 3)
682+
683+
return out
684+
685+
686+
def _native_attention_backward_op(
687+
ctx: torch.autograd.function.FunctionCtx,
688+
grad_out: torch.Tensor,
689+
*args,
690+
**kwargs,
691+
):
692+
query, key, value = ctx.saved_tensors
693+
694+
query.requires_grad_(True)
695+
key.requires_grad_(True)
696+
value.requires_grad_(True)
697+
698+
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
699+
out = torch.nn.functional.scaled_dot_product_attention(
700+
query=query_t,
701+
key=key_t,
702+
value=value_t,
703+
attn_mask=ctx.attn_mask,
704+
dropout_p=ctx.dropout_p,
705+
is_causal=ctx.is_causal,
706+
scale=ctx.scale,
707+
enable_gqa=ctx.enable_gqa,
708+
)
709+
out = out.permute(0, 2, 1, 3)
710+
711+
grad_out_t = grad_out.permute(0, 2, 1, 3)
712+
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
713+
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
714+
)
715+
716+
grad_query = grad_query_t.permute(0, 2, 1, 3)
717+
grad_key = grad_key_t.permute(0, 2, 1, 3)
718+
grad_value = grad_value_t.permute(0, 2, 1, 3)
719+
720+
return grad_query, grad_key, grad_value
721+
722+
643723
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
644724
# forward declaration:
645725
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -1514,6 +1594,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
15141594
@_AttentionBackendRegistry.register(
15151595
AttentionBackendName.NATIVE,
15161596
constraints=[_check_device, _check_shape],
1597+
supports_context_parallel=True,
15171598
)
15181599
def _native_attention(
15191600
query: torch.Tensor,
@@ -1529,18 +1610,35 @@ def _native_attention(
15291610
) -> torch.Tensor:
15301611
if return_lse:
15311612
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
1532-
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
1533-
out = torch.nn.functional.scaled_dot_product_attention(
1534-
query=query,
1535-
key=key,
1536-
value=value,
1537-
attn_mask=attn_mask,
1538-
dropout_p=dropout_p,
1539-
is_causal=is_causal,
1540-
scale=scale,
1541-
enable_gqa=enable_gqa,
1542-
)
1543-
out = out.permute(0, 2, 1, 3)
1613+
if _parallel_config is None:
1614+
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
1615+
out = torch.nn.functional.scaled_dot_product_attention(
1616+
query=query,
1617+
key=key,
1618+
value=value,
1619+
attn_mask=attn_mask,
1620+
dropout_p=dropout_p,
1621+
is_causal=is_causal,
1622+
scale=scale,
1623+
enable_gqa=enable_gqa,
1624+
)
1625+
out = out.permute(0, 2, 1, 3)
1626+
else:
1627+
out = _templated_context_parallel_attention(
1628+
query,
1629+
key,
1630+
value,
1631+
attn_mask,
1632+
dropout_p,
1633+
is_causal,
1634+
scale,
1635+
enable_gqa,
1636+
return_lse,
1637+
forward_op=_native_attention_forward_op,
1638+
backward_op=_native_attention_backward_op,
1639+
_parallel_config=_parallel_config,
1640+
)
1641+
15441642
return out
15451643

15461644

src/diffusers/models/auto_model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,13 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
147147
"force_download",
148148
"local_files_only",
149149
"proxies",
150-
"resume_download",
151150
"revision",
152151
"token",
153152
]
154153
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
155154

156155
# load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
157-
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
156+
load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder"]}
158157

159158
library = None
160159
orig_class_name = None
@@ -205,7 +204,6 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
205204
module_file=module_file,
206205
class_name=class_name,
207206
**hub_kwargs,
208-
**kwargs,
209207
)
210208
else:
211209
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates

src/diffusers/modular_pipelines/components_manager.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,11 @@ def __call__(self, hooks, model_id, model, execution_device):
164164

165165
device_type = execution_device.type
166166
device_module = getattr(torch, device_type, torch.cuda)
167-
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
167+
try:
168+
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
169+
except AttributeError:
170+
raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.")
171+
168172
mem_on_device = mem_on_device - self.memory_reserve_margin
169173
if current_module_size < mem_on_device:
170174
return []
@@ -699,6 +703,8 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
699703
if not is_accelerate_available():
700704
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
701705

706+
# TODO: add a warning if mem_get_info isn't available on `device`.
707+
702708
for name, component in self.components.items():
703709
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
704710
remove_hook_from_module(component, recurse=True)

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
598598
and getattr(block_state, "image_width", None) is not None
599599
):
600600
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
601-
image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
601+
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
602602
img_ids = FluxPipeline._prepare_latent_image_ids(
603603
None, image_latent_height // 2, image_latent_width // 2, device, dtype
604604
)

src/diffusers/modular_pipelines/flux/denoise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
5959
),
6060
InputParam(
6161
"guidance",
62-
required=True,
62+
required=False,
6363
type_hint=torch.Tensor,
6464
description="Guidance scale as a tensor",
6565
),
@@ -141,7 +141,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
141141
),
142142
InputParam(
143143
"guidance",
144-
required=True,
144+
required=False,
145145
type_hint=torch.Tensor,
146146
description="Guidance scale as a tensor",
147147
),

src/diffusers/modular_pipelines/flux/encoders.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def expected_components(self) -> List[ComponentSpec]:
9595
ComponentSpec(
9696
"image_processor",
9797
VaeImageProcessor,
98-
config=FrozenDict({"vae_scale_factor": 16}),
98+
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
9999
default_creation_method="from_config",
100100
),
101101
]
@@ -143,10 +143,6 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState):
143143
class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
144144
model_name = "flux-kontext"
145145

146-
def __init__(self, _auto_resize=True):
147-
self._auto_resize = _auto_resize
148-
super().__init__()
149-
150146
@property
151147
def description(self) -> str:
152148
return (
@@ -167,7 +163,7 @@ def expected_components(self) -> List[ComponentSpec]:
167163

168164
@property
169165
def inputs(self) -> List[InputParam]:
170-
return [InputParam("image")]
166+
return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]
171167

172168
@property
173169
def intermediate_outputs(self) -> List[OutputParam]:
@@ -195,7 +191,8 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState):
195191
img = images[0]
196192
image_height, image_width = components.image_processor.get_default_height_width(img)
197193
aspect_ratio = image_width / image_height
198-
if self._auto_resize:
194+
_auto_resize = block_state._auto_resize
195+
if _auto_resize:
199196
# Kontext is trained on specific resolutions, using one of them is recommended
200197
_, image_width, image_height = min(
201198
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS

src/diffusers/modular_pipelines/flux/inputs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
112112
block_state.prompt_embeds = block_state.prompt_embeds.view(
113113
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
114114
)
115+
pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
116+
block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(
117+
block_state.batch_size * block_state.num_images_per_prompt, -1
118+
)
115119
self.set_block_state(state, block_state)
116120

117121
return components, state

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,15 +305,15 @@ def from_pretrained(
305305
"cache_dir",
306306
"force_download",
307307
"local_files_only",
308+
"local_dir",
308309
"proxies",
309-
"resume_download",
310310
"revision",
311311
"subfolder",
312312
"token",
313313
]
314314
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
315315

316-
config = cls.load_config(pretrained_model_name_or_path)
316+
config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
317317
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
318318
trust_remote_code = resolve_trust_remote_code(
319319
trust_remote_code, pretrained_model_name_or_path, has_remote_code
@@ -331,7 +331,6 @@ def from_pretrained(
331331
module_file=module_file,
332332
class_name=class_name,
333333
**hub_kwargs,
334-
**kwargs,
335334
)
336335
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
337336
block_kwargs = {
@@ -2131,8 +2130,13 @@ def load_components(self, names: Optional[Union[List[str], str]] = None, **kwarg
21312130
component_load_kwargs[key] = value["default"]
21322131
try:
21332132
components_to_register[name] = spec.load(**component_load_kwargs)
2134-
except Exception as e:
2135-
logger.warning(f"Failed to create component '{name}': {e}")
2133+
except Exception:
2134+
logger.warning(
2135+
f"\nFailed to create component {name}:\n"
2136+
f"- Component spec: {spec}\n"
2137+
f"- load() called with kwargs: {component_load_kwargs}\n\n"
2138+
f"{traceback.format_exc()}"
2139+
)
21362140

21372141
# Register all components at once
21382142
self.register_components(**components_to_register)

0 commit comments

Comments
 (0)