Skip to content

Commit c739ee9

Browse files
author
yiyi@huggingface.co
committed
update conversion script
1 parent 76bb607 commit c739ee9

1 file changed

Lines changed: 95 additions & 47 deletions

File tree

scripts/convert_hunyuan_video1_5_to_diffusers.py

Lines changed: 95 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
"""
22
python scripts/convert_hunyuan_video1_5_to_diffusers.py \
3-
--original_state_dict_folder /raid/yiyi/new-model-vid \
4-
--output_transformer_path /raid/yiyi/hunyuanvideo15-480p_i2v-diffusers \
3+
--original_state_dict_repo_id tencent/HunyuanVideo-1.5\
4+
--output_path /fsx/yiyi/hy15/480p_i2v\
55
--transformer_type 480p_i2v \
66
--dtype fp32
77
"""
88

99
"""
1010
python scripts/convert_hunyuan_video1_5_to_diffusers.py \
11-
--original_state_dict_folder /raid/yiyi/new-model-vid \
12-
--output_vae_path /raid/yiyi/hunyuanvideo15-vae \
13-
--dtype fp32
11+
--original_state_dict_repo_id tencent/HunyuanVideo-1.5\
12+
--output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers \
13+
--dtype bf16 \
14+
--save_pipeline \
15+
--byt5_path /fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2\
16+
--transformer_type 480p_i2v
1417
"""
1518

1619
import argparse
@@ -22,11 +25,12 @@
2225
from huggingface_hub import snapshot_download, hf_hub_download
2326

2427
import pathlib
25-
from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15
28+
from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline
2629
from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer
2730

2831
import json
2932
import argparse
33+
import os
3034

3135
TRANSFORMER_CONFIGS = {
3236
"480p_i2v": {
@@ -49,6 +53,20 @@
4953
},
5054
}
5155

56+
SCHEDULER_CONFIGS = {
57+
"480p_i2v": {
58+
"shift": 5.0,
59+
},
60+
}
61+
62+
GUIDANCE_CONFIGS = {
63+
"480p_i2v": {
64+
"guidance_scale": 6.0,
65+
"embedded_guidance_scale": None,
66+
},
67+
68+
}
69+
5270
def swap_scale_shift(weight):
5371
shift, scale = weight.chunk(2, dim=0)
5472
new_weight = torch.cat([scale, shift], dim=0)
@@ -571,18 +589,16 @@ def convert_vae(args):
571589
vae.load_state_dict(state_dict, strict=True, assign=True)
572590
return vae
573591

574-
def save_text_encoder(output_path):
592+
def load_mllm():
593+
print(f" loading from Qwen/Qwen2.5-VL-7B-Instruct")
575594
text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", low_cpu_mem_usage=True)
576595
if hasattr(text_encoder, 'language_model'):
577596
text_encoder = text_encoder.language_model
578-
579-
580-
text_encoder.save_pretrained(output_path + "/text_encoder")
581-
582597
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right")
583-
tokenizer.save_pretrained(output_path + "/tokenizer")
598+
return text_encoder, tokenizer
584599

585600

601+
#copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89
586602
def add_special_token(
587603
tokenizer,
588604
text_encoder,
@@ -625,42 +641,36 @@ def add_special_token(
625641
text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
626642

627643

628-
def save_text_encoder_2(
629-
byt5_base_path,
630-
byt5_checkpoint_path,
631-
color_ann_path,
632-
font_ann_path,
633-
output_path,
634-
multilingual=True
635-
):
644+
645+
646+
def load_byt5(args):
636647
"""
637648
Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format.
638-
639-
Args:
640-
byt5_base_path: Path to base byt5-small model (e.g., "google/byt5-small")
641-
byt5_checkpoint_path: Path to Glyph-SDXL-v2 checkpoint (byt5_model.pt)
642-
color_ann_path: Path to color_idx.json
643-
font_ann_path: Path to multilingual_10-lang_idx.json
644-
output_path: Where to save the converted model
645-
multilingual: Whether to use multilingual font tokens
646649
"""
647650

648-
651+
649652
# 1. Load base tokenizer and encoder
650-
tokenizer = AutoTokenizer.from_pretrained(byt5_base_path)
653+
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
651654

652655
# Load as T5EncoderModel
653-
encoder = T5EncoderModel.from_pretrained(byt5_base_path)
656+
encoder = T5EncoderModel.from_pretrained("google/byt5-small")
654657

658+
byt5_checkpoint_path = os.path.join(args.byt5_path, "checkpoints/byt5_model.pt")
659+
color_ann_path = os.path.join(args.byt5_path, "assets/color_idx.json")
660+
font_ann_path = os.path.join(args.byt5_path, "assets/multilingual_10-lang_idx.json")
661+
655662
# 2. Add special tokens
656663
add_special_token(
657-
tokenizer,
658-
encoder,
664+
tokenizer=tokenizer,
665+
text_encoder=encoder,
666+
add_color=True,
667+
add_font=True,
659668
color_ann_path=color_ann_path,
660669
font_ann_path=font_ann_path,
661-
multilingual=multilingual
670+
multilingual=True,
662671
)
663672

673+
664674
# 3. Load Glyph-SDXL-v2 checkpoint
665675
print(f"\n3. Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}")
666676
checkpoint = torch.load(byt5_checkpoint_path, map_location='cpu')
@@ -694,11 +704,7 @@ def save_text_encoder_2(
694704
raise ValueError(f"Missing keys: {missing_keys}")
695705

696706

697-
# Save encoder
698-
encoder.save_pretrained(output_path + "/text_encoder_2")
699-
700-
# Save tokenizer
701-
tokenizer.save_pretrained(output_path + "/tokenizer_2")
707+
return encoder, tokenizer
702708

703709

704710
def get_args():
@@ -707,12 +713,26 @@ def get_args():
707713
"--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model"
708714
)
709715
parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict")
710-
parser.add_argument("--output_vae_path", type=str, default=None, help="Path where converted VAE should be saved")
711-
parser.add_argument("--output_transformer_path", type=str, default=None, help="Path where converted transformer should be saved")
716+
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved")
712717
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
713718
parser.add_argument(
714719
"--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys())
715720
)
721+
parser.add_argument(
722+
"--byt5_path",
723+
type=str,
724+
default=None,
725+
help=(
726+
"path to the downloaded byt5 checkpoint & assets. "
727+
"Note: They use Glyph-SDXL-v2 as byt5 encoder. You can download from modelscope like: "
728+
"`modelscope download --model AI-ModelScope/Glyph-SDXL-v2 --local_dir ./ckpts/text_encoder/Glyph-SDXL-v2` "
729+
"or manually download following the instructions on "
730+
"https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/checkpoints-download.md. "
731+
"The path should point to the Glyph-SDXL-v2 folder which should contain an `assets` folder and a `checkpoints` folder, "
732+
"like: Glyph-SDXL-v2/assets/... and Glyph-SDXL-v2/checkpoints/byt5_model.pt"
733+
),
734+
)
735+
parser.add_argument("--save_pipeline", action="store_true")
716736
return parser.parse_args()
717737

718738

@@ -726,16 +746,44 @@ def get_args():
726746
if __name__ == "__main__":
727747
args = get_args()
728748

749+
if args.save_pipeline and args.byt5_path is None:
750+
raise ValueError("Please provide --byt5_path when saving pipeline")
751+
729752
transformer = None
730753
dtype = DTYPE_MAPPING[args.dtype]
731754

732-
if args.output_transformer_path is not None:
733-
transformer = convert_transformer(args)
734-
transformer = transformer.to(dtype=dtype)
735-
transformer.save_pretrained(args.output_transformer_path, safe_serialization=True)
755+
transformer = convert_transformer(args)
756+
transformer = transformer.to(dtype=dtype)
757+
if not args.save_pipeline:
758+
transformer.save_pretrained(args.output_path, safe_serialization=True)
759+
else:
736760

737-
if args.output_vae_path is not None:
738761
vae = convert_vae(args)
739762
vae = vae.to(dtype=dtype)
740-
vae.save_pretrained(args.output_vae_path, safe_serialization=True)
763+
764+
765+
text_encoder, tokenizer = load_mllm()
766+
text_encoder_2, tokenizer_2 = load_byt5(args)
767+
text_encoder = text_encoder.to(dtype=dtype)
768+
text_encoder_2 = text_encoder_2.to(dtype=dtype)
769+
770+
flow_shift = SCHEDULER_CONFIGS[args.transformer_type]["shift"]
771+
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
772+
773+
guidance_scale = GUIDANCE_CONFIGS[args.transformer_type]["guidance_scale"]
774+
guider = ClassifierFreeGuidance(guidance_scale=guidance_scale)
775+
776+
pipeline = HunyuanVideo15Pipeline(
777+
vae=vae,
778+
text_encoder=text_encoder,
779+
text_encoder_2=text_encoder_2,
780+
tokenizer=tokenizer,
781+
tokenizer_2=tokenizer_2,
782+
transformer=transformer,
783+
guider=guider,
784+
scheduler=scheduler,
785+
)
786+
pipeline.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
787+
788+
741789

0 commit comments

Comments
 (0)