1+ """
2+ # Cosmos 2 Predict
3+
4+ Download checkpoint
5+ ```bash
6+ hf download nvidia/Cosmos-Predict2-2B-Text2Image
7+ ```
8+
9+ Convert checkpoint
10+ ```bash
11+ transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
12+
13+ python scripts/convert_cosmos_to_diffusers.py \
14+ --transformer_ckpt_path $transformer_ckpt_path \
15+ --transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
16+ --text_encoder_path google-t5/t5-11b \
17+ --tokenizer_path google-t5/t5-11b \
18+ --vae_type wan2.1 \
19+ --output_path converted/cosmos-p2-t2i-2b \
20+ --save_pipeline
21+ ```
22+
23+ # Cosmos 2.5 Predict
24+
25+ Download checkpoint
26+ ```bash
27+ hf download nvidia/Cosmos-Predict2.5-2B
28+ ```
29+
30+ Convert checkpoint
31+ ```bash
32+ # pre-trained
33+ transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
34+
35+ python scripts/convert_cosmos_to_diffusers.py \
36+ --transformer_type Cosmos-2.5-Predict-Base-2B \
37+ --transformer_ckpt_path $transformer_ckpt_path \
38+ --vae_type wan2.1 \
39+ --output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
40+ --save_pipeline
41+
42+ # post-trained
43+ transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt
44+
45+ python scripts/convert_cosmos_to_diffusers.py \
46+ --transformer_type Cosmos-2.5-Predict-Base-2B \
47+ --transformer_ckpt_path $transformer_ckpt_path \
48+ --vae_type wan2.1 \
49+ --output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
50+ --save_pipeline
51+ ```
52+
53+ ## 14B
54+
55+ ```bash
56+ hf download nvidia/Cosmos-Predict2.5-14B
57+ ```
58+
59+ ```bash
60+ # pre-trained
61+ transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt
62+
63+ python scripts/convert_cosmos_to_diffusers.py \
64+ --transformer_type Cosmos-2.5-Predict-Base-14B \
65+ --transformer_ckpt_path $transformer_ckpt_path \
66+ --vae_type wan2.1 \
67+ --output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
68+ --save_pipeline
69+
70+ # post-trained
71+ transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt
72+
73+ python scripts/convert_cosmos_to_diffusers.py \
74+ --transformer_type Cosmos-2.5-Predict-Base-14B \
75+ --transformer_ckpt_path $transformer_ckpt_path \
76+ --vae_type wan2.1 \
77+ --output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
78+ --save_pipeline
79+ ```
80+
81+ """
82+
183import argparse
284import pathlib
85+ import sys
386from typing import Any , Dict
487
588import torch
689from accelerate import init_empty_weights
790from huggingface_hub import snapshot_download
8- from transformers import T5EncoderModel , T5TokenizerFast
91+ from transformers import AutoTokenizer , Qwen2_5_VLForConditionalGeneration , T5EncoderModel , T5TokenizerFast
992
1093from diffusers import (
1194 AutoencoderKLCosmos ,
17100 CosmosVideoToWorldPipeline ,
18101 EDMEulerScheduler ,
19102 FlowMatchEulerDiscreteScheduler ,
103+ UniPCMultistepScheduler ,
20104)
105+ from diffusers .pipelines .cosmos .pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
21106
22107
23108def remove_keys_ (key : str , state_dict : Dict [str , Any ]):
@@ -233,6 +318,44 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
233318 "concat_padding_mask" : True ,
234319 "extra_pos_embed_type" : None ,
235320 },
321+ "Cosmos-2.5-Predict-Base-2B" : {
322+ "in_channels" : 16 + 1 ,
323+ "out_channels" : 16 ,
324+ "num_attention_heads" : 16 ,
325+ "attention_head_dim" : 128 ,
326+ "num_layers" : 28 ,
327+ "mlp_ratio" : 4.0 ,
328+ "text_embed_dim" : 1024 ,
329+ "adaln_lora_dim" : 256 ,
330+ "max_size" : (128 , 240 , 240 ),
331+ "patch_size" : (1 , 2 , 2 ),
332+ "rope_scale" : (1.0 , 3.0 , 3.0 ),
333+ "concat_padding_mask" : True ,
334+ # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
335+ "extra_pos_embed_type" : None ,
336+ "use_crossattn_projection" : True ,
337+ "crossattn_proj_in_channels" : 100352 ,
338+ "encoder_hidden_states_channels" : 1024 ,
339+ },
340+ "Cosmos-2.5-Predict-Base-14B" : {
341+ "in_channels" : 16 + 1 ,
342+ "out_channels" : 16 ,
343+ "num_attention_heads" : 40 ,
344+ "attention_head_dim" : 128 ,
345+ "num_layers" : 36 ,
346+ "mlp_ratio" : 4.0 ,
347+ "text_embed_dim" : 1024 ,
348+ "adaln_lora_dim" : 256 ,
349+ "max_size" : (128 , 240 , 240 ),
350+ "patch_size" : (1 , 2 , 2 ),
351+ "rope_scale" : (1.0 , 3.0 , 3.0 ),
352+ "concat_padding_mask" : True ,
353+ # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
354+ "extra_pos_embed_type" : None ,
355+ "use_crossattn_projection" : True ,
356+ "crossattn_proj_in_channels" : 100352 ,
357+ "encoder_hidden_states_channels" : 1024 ,
358+ },
236359}
237360
238361VAE_KEYS_RENAME_DICT = {
@@ -334,6 +457,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
334457 elif "Cosmos-2.0" in transformer_type :
335458 TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
336459 TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
460+ elif "Cosmos-2.5" in transformer_type :
461+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
462+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
337463 else :
338464 assert False
339465
@@ -347,6 +473,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
347473 new_key = new_key .removeprefix (PREFIX_KEY )
348474 for replace_key , rename_key in TRANSFORMER_KEYS_RENAME_DICT .items ():
349475 new_key = new_key .replace (replace_key , rename_key )
476+ print (key , "->" , new_key , flush = True )
350477 update_state_dict_ (original_state_dict , key , new_key )
351478
352479 for key in list (original_state_dict .keys ()):
@@ -355,6 +482,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
355482 continue
356483 handler_fn_inplace (key , original_state_dict )
357484
485+ expected_keys = set (transformer .state_dict ().keys ())
486+ mapped_keys = set (original_state_dict .keys ())
487+ missing_keys = expected_keys - mapped_keys
488+ unexpected_keys = mapped_keys - expected_keys
489+ if missing_keys :
490+ print (f"ERROR: missing keys ({ len (missing_keys )} from state_dict:" , flush = True , file = sys .stderr )
491+ for k in missing_keys :
492+ print (k )
493+ sys .exit (1 )
494+ if unexpected_keys :
495+ print (f"ERROR: unexpected keys ({ len (unexpected_keys )} ) from state_dict:" , flush = True , file = sys .stderr )
496+ for k in unexpected_keys :
497+ print (k )
498+ sys .exit (2 )
499+
358500 transformer .load_state_dict (original_state_dict , strict = True , assign = True )
359501 return transformer
360502
@@ -444,17 +586,45 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
444586 pipe .save_pretrained (args .output_path , safe_serialization = True , max_shard_size = "5GB" )
445587
446588
def save_pipeline_cosmos2_5(args, transformer, vae):
    """Assemble a Cosmos 2.5 Predict base pipeline and save it to ``args.output_path``.

    Args:
        args: Parsed CLI namespace. Reads ``text_encoder_path``, ``tokenizer_path``
            and ``output_path``.
        transformer: Transformer module to place in the pipeline.
        vae: VAE module to place in the pipeline.
    """
    # A falsy path (e.g. the flag was not given) falls back to the default repo id.
    encoder_source = args.text_encoder_path if args.text_encoder_path else "nvidia/Cosmos-Reason1-7B"
    tokenizer_source = args.tokenizer_path if args.tokenizer_path else "Qwen/Qwen2.5-VL-7B-Instruct"

    encoder_load_kwargs = {"torch_dtype": "auto", "device_map": "cpu"}
    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(encoder_source, **encoder_load_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

    scheduler_settings = {
        "use_karras_sigmas": True,
        "use_flow_sigmas": True,
        "prediction_type": "flow_prediction",
        "sigma_max": 200.0,
        "sigma_min": 0.01,
    }
    scheduler = UniPCMultistepScheduler(**scheduler_settings)

    components = {
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "transformer": transformer,
        "vae": vae,
        "scheduler": scheduler,
        # Stand-in callable that accepts anything and returns None.
        "safety_checker": lambda *args, **kwargs: None,
    }
    pipe = Cosmos2_5_PredictBasePipeline(**components)
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
615+
616+
447617def get_args ():
448618 parser = argparse .ArgumentParser ()
449619 parser .add_argument ("--transformer_type" , type = str , default = None , choices = list (TRANSFORMER_CONFIGS .keys ()))
450620 parser .add_argument (
451621 "--transformer_ckpt_path" , type = str , default = None , help = "Path to original transformer checkpoint"
452622 )
453623 parser .add_argument (
454- "--vae_type" , type = str , default = None , choices = ["none " , * list (VAE_CONFIGS .keys ())], help = "Type of VAE"
624+ "--vae_type" , type = str , default = "wan2.1" , choices = ["wan2.1 " , * list (VAE_CONFIGS .keys ())], help = "Type of VAE"
455625 )
456- parser .add_argument ("--text_encoder_path" , type = str , default = "google-t5/t5-11b" )
457- parser .add_argument ("--tokenizer_path" , type = str , default = "google-t5/t5-11b" )
626+ parser .add_argument ("--text_encoder_path" , type = str , default = None )
627+ parser .add_argument ("--tokenizer_path" , type = str , default = None )
458628 parser .add_argument ("--save_pipeline" , action = "store_true" )
459629 parser .add_argument ("--output_path" , type = str , required = True , help = "Path where converted model should be saved" )
460630 parser .add_argument ("--dtype" , default = "bf16" , help = "Torch dtype to save the transformer in." )
@@ -477,8 +647,6 @@ def get_args():
477647 if args .save_pipeline :
478648 assert args .transformer_ckpt_path is not None
479649 assert args .vae_type is not None
480- assert args .text_encoder_path is not None
481- assert args .tokenizer_path is not None
482650
483651 if args .transformer_ckpt_path is not None :
484652 weights_only = "Cosmos-1.0" in args .transformer_type
@@ -490,17 +658,26 @@ def get_args():
490658 if args .vae_type is not None :
491659 if "Cosmos-1.0" in args .transformer_type :
492660 vae = convert_vae (args .vae_type )
493- else :
661+ elif "Cosmos-2.0" in args . transformer_type or "Cosmos-2.5" in args . transformer_type :
494662 vae = AutoencoderKLWan .from_pretrained (
495663 "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" , subfolder = "vae" , torch_dtype = torch .float32
496664 )
665+ else :
666+ raise AssertionError (f"{ args .transformer_type } not supported" )
667+
497668 if not args .save_pipeline :
498669 vae .save_pretrained (args .output_path , safe_serialization = True , max_shard_size = "5GB" )
499670
500671 if args .save_pipeline :
501672 if "Cosmos-1.0" in args .transformer_type :
673+ assert args .text_encoder_path is not None
674+ assert args .tokenizer_path is not None
502675 save_pipeline_cosmos_1_0 (args , transformer , vae )
503676 elif "Cosmos-2.0" in args .transformer_type :
677+ assert args .text_encoder_path is not None
678+ assert args .tokenizer_path is not None
504679 save_pipeline_cosmos_2_0 (args , transformer , vae )
680+ elif "Cosmos-2.5" in args .transformer_type :
681+ save_pipeline_cosmos2_5 (args , transformer , vae )
505682 else :
506- assert False
683+ raise AssertionError ( f" { args . transformer_type } not supported" )
0 commit comments