Skip to content

Commit a92ed87

Browse files
authored
Merge branch 'main' into pipeline-specific-mixins
2 parents efadb7a + 1cdb872 commit a92ed87

33 files changed

Lines changed: 3706 additions & 1211 deletions

docs/source/en/api/pipelines/cosmos.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ output.save("output.png")
7070
- all
7171
- __call__
7272

73+
## Cosmos2_5_PredictBasePipeline
74+
75+
[[autodoc]] Cosmos2_5_PredictBasePipeline
76+
- all
77+
- __call__
78+
7379
## CosmosPipelineOutput
7480

7581
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

examples/community/pipeline_hunyuandit_differential_img2img.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
BertModel,
2222
BertTokenizer,
2323
CLIPImageProcessor,
24-
MT5Tokenizer,
2524
T5EncoderModel,
25+
T5Tokenizer,
2626
)
2727

2828
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
260260
The HunyuanDiT model designed by Tencent Hunyuan.
261261
text_encoder_2 (`T5EncoderModel`):
262262
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
263-
tokenizer_2 (`MT5Tokenizer`):
263+
tokenizer_2 (`T5Tokenizer`):
264264
The tokenizer for the mT5 embedder.
265265
scheduler ([`DDPMScheduler`]):
266266
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -295,7 +295,7 @@ def __init__(
295295
feature_extractor: CLIPImageProcessor,
296296
requires_safety_checker: bool = True,
297297
text_encoder_2=T5EncoderModel,
298-
tokenizer_2=MT5Tokenizer,
298+
tokenizer_2=T5Tokenizer,
299299
):
300300
super().__init__()
301301

scripts/convert_cosmos_to_diffusers.py

Lines changed: 185 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,94 @@
1+
"""
2+
# Cosmos 2 Predict
3+
4+
Download checkpoint
5+
```bash
6+
hf download nvidia/Cosmos-Predict2-2B-Text2Image
7+
```
8+
9+
Convert checkpoint
10+
```bash
11+
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
12+
13+
python scripts/convert_cosmos_to_diffusers.py \
14+
--transformer_ckpt_path $transformer_ckpt_path \
15+
--transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
16+
--text_encoder_path google-t5/t5-11b \
17+
--tokenizer_path google-t5/t5-11b \
18+
--vae_type wan2.1 \
19+
--output_path converted/cosmos-p2-t2i-2b \
20+
--save_pipeline
21+
```
22+
23+
# Cosmos 2.5 Predict
24+
25+
Download checkpoint
26+
```bash
27+
hf download nvidia/Cosmos-Predict2.5-2B
28+
```
29+
30+
Convert checkpoint
31+
```bash
32+
# pre-trained
33+
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
34+
35+
python scripts/convert_cosmos_to_diffusers.py \
36+
--transformer_type Cosmos-2.5-Predict-Base-2B \
37+
--transformer_ckpt_path $transformer_ckpt_path \
38+
--vae_type wan2.1 \
39+
--output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
40+
--save_pipeline
41+
42+
# post-trained
43+
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt
44+
45+
python scripts/convert_cosmos_to_diffusers.py \
46+
--transformer_type Cosmos-2.5-Predict-Base-2B \
47+
--transformer_ckpt_path $transformer_ckpt_path \
48+
--vae_type wan2.1 \
49+
--output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
50+
--save_pipeline
51+
```
52+
53+
## 14B
54+
55+
```bash
56+
hf download nvidia/Cosmos-Predict2.5-14B
57+
```
58+
59+
```bash
60+
# pre-trained
61+
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt
62+
63+
python scripts/convert_cosmos_to_diffusers.py \
64+
--transformer_type Cosmos-2.5-Predict-Base-14B \
65+
--transformer_ckpt_path $transformer_ckpt_path \
66+
--vae_type wan2.1 \
67+
--output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
68+
--save_pipeline
69+
70+
# post-trained
71+
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt
72+
73+
python scripts/convert_cosmos_to_diffusers.py \
74+
--transformer_type Cosmos-2.5-Predict-Base-14B \
75+
--transformer_ckpt_path $transformer_ckpt_path \
76+
--vae_type wan2.1 \
77+
--output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
78+
--save_pipeline
79+
```
80+
81+
"""
82+
183
import argparse
284
import pathlib
85+
import sys
386
from typing import Any, Dict
487

588
import torch
689
from accelerate import init_empty_weights
790
from huggingface_hub import snapshot_download
8-
from transformers import T5EncoderModel, T5TokenizerFast
91+
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
992

1093
from diffusers import (
1194
AutoencoderKLCosmos,
@@ -17,7 +100,9 @@
17100
CosmosVideoToWorldPipeline,
18101
EDMEulerScheduler,
19102
FlowMatchEulerDiscreteScheduler,
103+
UniPCMultistepScheduler,
20104
)
105+
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
21106

22107

23108
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -233,6 +318,44 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
233318
"concat_padding_mask": True,
234319
"extra_pos_embed_type": None,
235320
},
321+
"Cosmos-2.5-Predict-Base-2B": {
322+
"in_channels": 16 + 1,
323+
"out_channels": 16,
324+
"num_attention_heads": 16,
325+
"attention_head_dim": 128,
326+
"num_layers": 28,
327+
"mlp_ratio": 4.0,
328+
"text_embed_dim": 1024,
329+
"adaln_lora_dim": 256,
330+
"max_size": (128, 240, 240),
331+
"patch_size": (1, 2, 2),
332+
"rope_scale": (1.0, 3.0, 3.0),
333+
"concat_padding_mask": True,
334+
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
335+
"extra_pos_embed_type": None,
336+
"use_crossattn_projection": True,
337+
"crossattn_proj_in_channels": 100352,
338+
"encoder_hidden_states_channels": 1024,
339+
},
340+
"Cosmos-2.5-Predict-Base-14B": {
341+
"in_channels": 16 + 1,
342+
"out_channels": 16,
343+
"num_attention_heads": 40,
344+
"attention_head_dim": 128,
345+
"num_layers": 36,
346+
"mlp_ratio": 4.0,
347+
"text_embed_dim": 1024,
348+
"adaln_lora_dim": 256,
349+
"max_size": (128, 240, 240),
350+
"patch_size": (1, 2, 2),
351+
"rope_scale": (1.0, 3.0, 3.0),
352+
"concat_padding_mask": True,
353+
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
354+
"extra_pos_embed_type": None,
355+
"use_crossattn_projection": True,
356+
"crossattn_proj_in_channels": 100352,
357+
"encoder_hidden_states_channels": 1024,
358+
},
236359
}
237360

238361
VAE_KEYS_RENAME_DICT = {
@@ -334,6 +457,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
334457
elif "Cosmos-2.0" in transformer_type:
335458
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
336459
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
460+
elif "Cosmos-2.5" in transformer_type:
461+
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
462+
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
337463
else:
338464
assert False
339465

@@ -347,6 +473,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
347473
new_key = new_key.removeprefix(PREFIX_KEY)
348474
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
349475
new_key = new_key.replace(replace_key, rename_key)
476+
print(key, "->", new_key, flush=True)
350477
update_state_dict_(original_state_dict, key, new_key)
351478

352479
for key in list(original_state_dict.keys()):
@@ -355,6 +482,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
355482
continue
356483
handler_fn_inplace(key, original_state_dict)
357484

485+
expected_keys = set(transformer.state_dict().keys())
486+
mapped_keys = set(original_state_dict.keys())
487+
missing_keys = expected_keys - mapped_keys
488+
unexpected_keys = mapped_keys - expected_keys
489+
if missing_keys:
490+
print(f"ERROR: missing keys ({len(missing_keys)}) from state_dict:", flush=True, file=sys.stderr)
491+
for k in missing_keys:
492+
print(k)
493+
sys.exit(1)
494+
if unexpected_keys:
495+
print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
496+
for k in unexpected_keys:
497+
print(k)
498+
sys.exit(2)
499+
358500
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
359501
return transformer
360502

@@ -444,17 +586,45 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
444586
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
445587

446588

589+
def save_pipeline_cosmos2_5(args, transformer, vae):
590+
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
591+
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
592+
593+
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
594+
text_encoder_path, torch_dtype="auto", device_map="cpu"
595+
)
596+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
597+
598+
scheduler = UniPCMultistepScheduler(
599+
use_karras_sigmas=True,
600+
use_flow_sigmas=True,
601+
prediction_type="flow_prediction",
602+
sigma_max=200.0,
603+
sigma_min=0.01,
604+
)
605+
606+
pipe = Cosmos2_5_PredictBasePipeline(
607+
text_encoder=text_encoder,
608+
tokenizer=tokenizer,
609+
transformer=transformer,
610+
vae=vae,
611+
scheduler=scheduler,
612+
safety_checker=lambda *args, **kwargs: None,
613+
)
614+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
615+
616+
447617
def get_args():
448618
parser = argparse.ArgumentParser()
449619
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
450620
parser.add_argument(
451621
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
452622
)
453623
parser.add_argument(
454-
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
624+
"--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
455625
)
456-
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
457-
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
626+
parser.add_argument("--text_encoder_path", type=str, default=None)
627+
parser.add_argument("--tokenizer_path", type=str, default=None)
458628
parser.add_argument("--save_pipeline", action="store_true")
459629
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
460630
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -477,8 +647,6 @@ def get_args():
477647
if args.save_pipeline:
478648
assert args.transformer_ckpt_path is not None
479649
assert args.vae_type is not None
480-
assert args.text_encoder_path is not None
481-
assert args.tokenizer_path is not None
482650

483651
if args.transformer_ckpt_path is not None:
484652
weights_only = "Cosmos-1.0" in args.transformer_type
@@ -490,17 +658,26 @@ def get_args():
490658
if args.vae_type is not None:
491659
if "Cosmos-1.0" in args.transformer_type:
492660
vae = convert_vae(args.vae_type)
493-
else:
661+
elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
494662
vae = AutoencoderKLWan.from_pretrained(
495663
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
496664
)
665+
else:
666+
raise AssertionError(f"{args.transformer_type} not supported")
667+
497668
if not args.save_pipeline:
498669
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
499670

500671
if args.save_pipeline:
501672
if "Cosmos-1.0" in args.transformer_type:
673+
assert args.text_encoder_path is not None
674+
assert args.tokenizer_path is not None
502675
save_pipeline_cosmos_1_0(args, transformer, vae)
503676
elif "Cosmos-2.0" in args.transformer_type:
677+
assert args.text_encoder_path is not None
678+
assert args.tokenizer_path is not None
504679
save_pipeline_cosmos_2_0(args, transformer, vae)
680+
elif "Cosmos-2.5" in args.transformer_type:
681+
save_pipeline_cosmos2_5(args, transformer, vae)
505682
else:
506-
assert False
683+
raise AssertionError(f"{args.transformer_type} not supported")

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,7 @@
463463
"CogView4ControlPipeline",
464464
"CogView4Pipeline",
465465
"ConsisIDPipeline",
466+
"Cosmos2_5_PredictBasePipeline",
466467
"Cosmos2TextToImagePipeline",
467468
"Cosmos2VideoToWorldPipeline",
468469
"CosmosTextToWorldPipeline",
@@ -674,6 +675,7 @@
674675
"ZImageControlNetInpaintPipeline",
675676
"ZImageControlNetPipeline",
676677
"ZImageImg2ImgPipeline",
678+
"ZImageOmniPipeline",
677679
"ZImagePipeline",
678680
]
679681
)
@@ -1175,6 +1177,7 @@
11751177
CogView4ControlPipeline,
11761178
CogView4Pipeline,
11771179
ConsisIDPipeline,
1180+
Cosmos2_5_PredictBasePipeline,
11781181
Cosmos2TextToImagePipeline,
11791182
Cosmos2VideoToWorldPipeline,
11801183
CosmosTextToWorldPipeline,
@@ -1384,6 +1387,7 @@
13841387
ZImageControlNetInpaintPipeline,
13851388
ZImageControlNetPipeline,
13861389
ZImageImg2ImgPipeline,
1390+
ZImageOmniPipeline,
13871391
ZImagePipeline,
13881392
)
13891393

src/diffusers/guiders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
2626
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
2727
from .guider_utils import BaseGuidance
28+
from .magnitude_aware_guidance import MagnitudeAwareGuidance
2829
from .perturbed_attention_guidance import PerturbedAttentionGuidance
2930
from .skip_layer_guidance import SkipLayerGuidance
3031
from .smoothed_energy_guidance import SmoothedEnergyGuidance

0 commit comments

Comments
 (0)