Skip to content

Commit a88d11b

Browse files
committed
resolve conflicts.
2 parents a9165eb + 58f3771 commit a88d11b

8 files changed

Lines changed: 176 additions & 25 deletions

File tree

docs/source/en/_toctree.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,6 @@
529529
title: Kandinsky 2.2
530530
- local: api/pipelines/kandinsky3
531531
title: Kandinsky 3
532-
- local: api/pipelines/kandinsky5
533-
title: Kandinsky 5
534532
- local: api/pipelines/kolors
535533
title: Kolors
536534
- local: api/pipelines/latent_consistency_models
@@ -638,6 +636,8 @@
638636
title: HunyuanVideo
639637
- local: api/pipelines/i2vgenxl
640638
title: I2VGen-XL
639+
- local: api/pipelines/kandinsky5_video
640+
title: Kandinsky 5.0 Video
641641
- local: api/pipelines/latte
642642
title: Latte
643643
- local: api/pipelines/ltx_video

docs/source/en/api/pipelines/kandinsky5.md renamed to docs/source/en/api/pipelines/kandinsky5_video.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
77
specific language governing permissions and limitations under the License.
88
-->
99

10-
# Kandinsky 5.0
10+
# Kandinsky 5.0 Video
1111

12-
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
12+
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
1313

1414

1515
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
@@ -92,7 +92,7 @@ pipe = pipe.to("cuda")
9292

9393
pipe.transformer.set_attention_backend(
9494
"flex"
95-
) # <--- Set attention backend to Flex
95+
) # <--- Set attention backend to Flex
9696
pipe.transformer.compile(
9797
mode="max-autotune-no-cudagraphs",
9898
dynamic=True
@@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9)
115115
```
116116

117117
### Diffusion Distilled model
118-
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
118+
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
119119

120120
```python
121121
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"

examples/unconditional_image_generation/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways:
104104
- you can either provide your own folder as `--train_data_dir`
105105
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
106106

107+
If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
108+
107109
Below, we explain both in more detail.
108110

109111
#### Provide the dataset as a folder

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
5252
return res.expand(broadcast_shape)
5353

5454

55+
def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return `tensor` as a three-channel `(3, H, W)` image tensor.

    Args:
        tensor (`torch.Tensor`):
            A single image shaped `(H, W)` or `(C, H, W)`. The dtype is left untouched.

    Returns:
        `torch.Tensor`: A `(3, H, W)` tensor. Three-channel inputs are returned as-is, one-channel inputs are
        replicated across the three channels, and any extra channels are dropped.

    Raises:
        ValueError: If the tensor is not 2-D or 3-D, or if it has zero channels.
    """
    if tensor.ndim == 2:
        tensor = tensor.unsqueeze(0)
    if tensor.ndim != 3:
        # Anything else is not a single image, and `shape[0]` would not be a channel count.
        raise ValueError(f"Expected a (H, W) or (C, H, W) image tensor, got {tensor.ndim} dimensions")

    channels = tensor.shape[0]
    if channels == 3:
        return tensor
    if channels in (1, 2):
        # A two-channel image is assumed to be luminance + alpha (e.g. PIL mode "LA"). Only the first channel is
        # replicated so the alpha plane never ends up as a colour channel; this mirrors the >3-channel case below
        # and what `image.convert("RGB")` does on the default (non precision-preserving) path.
        return tensor[:1].repeat(3, 1, 1)
    if channels > 3:
        # e.g. RGBA: keep the first three channels and drop the rest.
        return tensor[:3]
    raise ValueError(f"Unsupported number of channels: {channels}")
71+
72+
5573
def parse_args():
5674
parser = argparse.ArgumentParser(description="Simple example of a training script.")
5775
parser.add_argument(
@@ -260,6 +278,11 @@ def parse_args():
260278
parser.add_argument(
261279
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
262280
)
281+
parser.add_argument(
282+
"--preserve_input_precision",
283+
action="store_true",
284+
help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
285+
)
263286

264287
args = parser.parse_args()
265288
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -453,19 +476,41 @@ def load_model_hook(models, input_dir):
453476
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
454477

455478
# Preprocessing the datasets and DataLoaders creation.
479+
spatial_augmentations = [
480+
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
481+
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
482+
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
483+
]
484+
456485
augmentations = transforms.Compose(
457-
[
458-
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
459-
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
460-
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
486+
spatial_augmentations
487+
+ [
461488
transforms.ToTensor(),
462489
transforms.Normalize([0.5], [0.5]),
463490
]
464491
)
465492

493+
precision_augmentations = transforms.Compose(
494+
[
495+
transforms.PILToTensor(),
496+
transforms.Lambda(_ensure_three_channels),
497+
transforms.ConvertImageDtype(torch.float32),
498+
]
499+
+ spatial_augmentations
500+
+ [transforms.Normalize([0.5], [0.5])]
501+
)
502+
466503
def transform_images(examples):
467-
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
468-
return {"input": images}
504+
processed = []
505+
for image in examples["image"]:
506+
if not args.preserve_input_precision:
507+
processed.append(augmentations(image.convert("RGB")))
508+
else:
509+
precise_image = image
510+
if precise_image.mode == "P":
511+
precise_image = precise_image.convert("RGB")
512+
processed.append(precision_augmentations(precise_image))
513+
return {"input": processed}
469514

470515
logger.info(f"Dataset size: {len(dataset)}")
471516

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2213,6 +2213,10 @@ def convert_key(key: str) -> str:
22132213

22142214
state_dict = {convert_key(k): v for k, v in state_dict.items()}
22152215

2216+
has_default = any("default." in k for k in state_dict)
2217+
if has_default:
2218+
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
2219+
22162220
converted_state_dict = {}
22172221
all_keys = list(state_dict.keys())
22182222
down_key = ".lora_down.weight"

src/diffusers/loaders/lora_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4940,7 +4940,8 @@ def lora_state_dict(
49404940
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
49414941
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
49424942
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
4943-
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
4943+
has_default = any("default." in k for k in state_dict)
4944+
if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
49444945
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
49454946

49464947
out = (state_dict, metadata) if return_lora_metadata else state_dict

src/diffusers/models/attention_dispatch.py

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,86 @@ def _(
649649
# ===== Helper functions to use attention backends with templated CP autograd functions =====
650650

651651

652+
def _native_attention_forward_op(
653+
ctx: torch.autograd.function.FunctionCtx,
654+
query: torch.Tensor,
655+
key: torch.Tensor,
656+
value: torch.Tensor,
657+
attn_mask: Optional[torch.Tensor] = None,
658+
dropout_p: float = 0.0,
659+
is_causal: bool = False,
660+
scale: Optional[float] = None,
661+
enable_gqa: bool = False,
662+
return_lse: bool = False,
663+
_save_ctx: bool = True,
664+
_parallel_config: Optional["ParallelConfig"] = None,
665+
):
666+
# Native attention does not return_lse
667+
if return_lse:
668+
raise ValueError("Native attention does not support return_lse=True")
669+
670+
# used for backward pass
671+
if _save_ctx:
672+
ctx.save_for_backward(query, key, value)
673+
ctx.attn_mask = attn_mask
674+
ctx.dropout_p = dropout_p
675+
ctx.is_causal = is_causal
676+
ctx.scale = scale
677+
ctx.enable_gqa = enable_gqa
678+
679+
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
680+
out = torch.nn.functional.scaled_dot_product_attention(
681+
query=query,
682+
key=key,
683+
value=value,
684+
attn_mask=attn_mask,
685+
dropout_p=dropout_p,
686+
is_causal=is_causal,
687+
scale=scale,
688+
enable_gqa=enable_gqa,
689+
)
690+
out = out.permute(0, 2, 1, 3)
691+
692+
return out
693+
694+
695+
def _native_attention_backward_op(
696+
ctx: torch.autograd.function.FunctionCtx,
697+
grad_out: torch.Tensor,
698+
*args,
699+
**kwargs,
700+
):
701+
query, key, value = ctx.saved_tensors
702+
703+
query.requires_grad_(True)
704+
key.requires_grad_(True)
705+
value.requires_grad_(True)
706+
707+
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
708+
out = torch.nn.functional.scaled_dot_product_attention(
709+
query=query_t,
710+
key=key_t,
711+
value=value_t,
712+
attn_mask=ctx.attn_mask,
713+
dropout_p=ctx.dropout_p,
714+
is_causal=ctx.is_causal,
715+
scale=ctx.scale,
716+
enable_gqa=ctx.enable_gqa,
717+
)
718+
out = out.permute(0, 2, 1, 3)
719+
720+
grad_out_t = grad_out.permute(0, 2, 1, 3)
721+
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
722+
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
723+
)
724+
725+
grad_query = grad_query_t.permute(0, 2, 1, 3)
726+
grad_key = grad_key_t.permute(0, 2, 1, 3)
727+
grad_value = grad_value_t.permute(0, 2, 1, 3)
728+
729+
return grad_query, grad_key, grad_value
730+
731+
652732
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
653733
# forward declaration:
654734
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
@@ -1523,6 +1603,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
15231603
@_AttentionBackendRegistry.register(
15241604
AttentionBackendName.NATIVE,
15251605
constraints=[_check_device, _check_shape],
1606+
supports_context_parallel=True,
15261607
)
15271608
def _native_attention(
15281609
query: torch.Tensor,
@@ -1538,18 +1619,35 @@ def _native_attention(
15381619
) -> torch.Tensor:
15391620
if return_lse:
15401621
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
1541-
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
1542-
out = torch.nn.functional.scaled_dot_product_attention(
1543-
query=query,
1544-
key=key,
1545-
value=value,
1546-
attn_mask=attn_mask,
1547-
dropout_p=dropout_p,
1548-
is_causal=is_causal,
1549-
scale=scale,
1550-
enable_gqa=enable_gqa,
1551-
)
1552-
out = out.permute(0, 2, 1, 3)
1622+
if _parallel_config is None:
1623+
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
1624+
out = torch.nn.functional.scaled_dot_product_attention(
1625+
query=query,
1626+
key=key,
1627+
value=value,
1628+
attn_mask=attn_mask,
1629+
dropout_p=dropout_p,
1630+
is_causal=is_causal,
1631+
scale=scale,
1632+
enable_gqa=enable_gqa,
1633+
)
1634+
out = out.permute(0, 2, 1, 3)
1635+
else:
1636+
out = _templated_context_parallel_attention(
1637+
query,
1638+
key,
1639+
value,
1640+
attn_mask,
1641+
dropout_p,
1642+
is_causal,
1643+
scale,
1644+
enable_gqa,
1645+
return_lse,
1646+
forward_op=_native_attention_forward_op,
1647+
backward_op=_native_attention_backward_op,
1648+
_parallel_config=_parallel_config,
1649+
)
1650+
15531651
return out
15541652

15551653

src/diffusers/utils/dynamic_modules_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def get_cached_module_file(
358358
proxies=proxies,
359359
local_files_only=local_files_only,
360360
local_dir=local_dir,
361+
revision=revision,
361362
token=token,
362363
)
363364
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))

0 commit comments

Comments
 (0)