|
| 1 | +""" |
| 2 | +Convert Anima checkpoints to Diffusers format. |
| 3 | +
|
| 4 | +Example: |
| 5 | +```bash |
| 6 | +python scripts/convert_anima_to_diffusers.py \ |
| 7 | + --transformer_ckpt_path anima_model/anima-preview3-base.safetensors \ |
| 8 | + --text_encoder_ckpt_path anima_model/qwen_3_06b_base.safetensors \ |
| 9 | + --vae_ckpt_path anima_model/qwen_image_vae.safetensors \ |
| 10 | + --qwen_tokenizer_path path/to/qwen25_tokenizer \ |
| 11 | + --t5_tokenizer_path path/to/t5_tokenizer \ |
| 12 | + --output_path anima_model/anima-preview3-diffusers \ |
| 13 | + --save_pipeline |
| 14 | +``` |
| 15 | +""" |
| 16 | + |
| 17 | +import argparse |
| 18 | +import pathlib |
| 19 | +import sys |
| 20 | +from typing import Any |
| 21 | + |
| 22 | +import torch |
| 23 | +from accelerate import init_empty_weights |
| 24 | +from convert_cosmos_to_diffusers import convert_transformer |
| 25 | +from safetensors.torch import load_file |
| 26 | +from transformers import AutoTokenizer, Qwen3Config, Qwen3Model, T5TokenizerFast |
| 27 | + |
| 28 | +from diffusers import ( |
| 29 | + AnimaAutoBlocks, |
| 30 | + AnimaTextConditioner, |
| 31 | + AutoencoderKLQwenImage, |
| 32 | + FlowMatchEulerDiscreteScheduler, |
| 33 | +) |
| 34 | + |
| 35 | + |
| 36 | +DTYPE_MAPPING = { |
| 37 | + "fp32": torch.float32, |
| 38 | + "fp16": torch.float16, |
| 39 | + "bf16": torch.bfloat16, |
| 40 | +} |
| 41 | + |
| 42 | + |
| 43 | +def rename_residual_key(key: str) -> str: |
| 44 | + replacements = { |
| 45 | + ".residual.0.": ".norm1.", |
| 46 | + ".residual.2.": ".conv1.", |
| 47 | + ".residual.3.": ".norm2.", |
| 48 | + ".residual.6.": ".conv2.", |
| 49 | + ".shortcut.": ".conv_shortcut.", |
| 50 | + } |
| 51 | + for old, new in replacements.items(): |
| 52 | + key = key.replace(old, new) |
| 53 | + return key |
| 54 | + |
| 55 | + |
| 56 | +def rename_mid_key(key: str) -> str: |
| 57 | + replacements = { |
| 58 | + ".middle.0.": ".mid_block.resnets.0.", |
| 59 | + ".middle.1.": ".mid_block.attentions.0.", |
| 60 | + ".middle.2.": ".mid_block.resnets.1.", |
| 61 | + } |
| 62 | + for old, new in replacements.items(): |
| 63 | + key = key.replace(old, new) |
| 64 | + return rename_residual_key(key) |
| 65 | + |
| 66 | + |
| 67 | +def rename_decoder_upsample_key(key: str) -> str: |
| 68 | + prefix = "decoder.upsamples." |
| 69 | + suffix = key.removeprefix(prefix) |
| 70 | + index_str, rest = suffix.split(".", 1) |
| 71 | + index = int(index_str) |
| 72 | + |
| 73 | + if index in (3, 7, 11): |
| 74 | + block_index = (index - 3) // 4 |
| 75 | + new_key = f"decoder.up_blocks.{block_index}.upsamplers.0.{rest}" |
| 76 | + else: |
| 77 | + block_index = index // 4 |
| 78 | + resnet_index = index % 4 |
| 79 | + new_key = f"decoder.up_blocks.{block_index}.resnets.{resnet_index}.{rest}" |
| 80 | + |
| 81 | + return rename_residual_key(new_key) |
| 82 | + |
| 83 | + |
| 84 | +def convert_qwen_image_vae_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: |
| 85 | + converted_state_dict = {} |
| 86 | + for key, value in state_dict.items(): |
| 87 | + if key.startswith("conv1."): |
| 88 | + new_key = key.replace("conv1.", "quant_conv.", 1) |
| 89 | + elif key.startswith("conv2."): |
| 90 | + new_key = key.replace("conv2.", "post_quant_conv.", 1) |
| 91 | + elif key.startswith("encoder.conv1."): |
| 92 | + new_key = key.replace("encoder.conv1.", "encoder.conv_in.", 1) |
| 93 | + elif key.startswith("decoder.conv1."): |
| 94 | + new_key = key.replace("decoder.conv1.", "decoder.conv_in.", 1) |
| 95 | + elif key.startswith("encoder.downsamples."): |
| 96 | + new_key = rename_residual_key(key.replace("encoder.downsamples.", "encoder.down_blocks.", 1)) |
| 97 | + elif key.startswith("decoder.upsamples."): |
| 98 | + new_key = rename_decoder_upsample_key(key) |
| 99 | + elif key.startswith("encoder.middle.") or key.startswith("decoder.middle."): |
| 100 | + new_key = rename_mid_key(key) |
| 101 | + elif key.startswith("encoder.head.0."): |
| 102 | + new_key = key.replace("encoder.head.0.", "encoder.norm_out.", 1) |
| 103 | + elif key.startswith("encoder.head.2."): |
| 104 | + new_key = key.replace("encoder.head.2.", "encoder.conv_out.", 1) |
| 105 | + elif key.startswith("decoder.head.0."): |
| 106 | + new_key = key.replace("decoder.head.0.", "decoder.norm_out.", 1) |
| 107 | + elif key.startswith("decoder.head.2."): |
| 108 | + new_key = key.replace("decoder.head.2.", "decoder.conv_out.", 1) |
| 109 | + else: |
| 110 | + new_key = rename_residual_key(key) |
| 111 | + |
| 112 | + if new_key in converted_state_dict: |
| 113 | + raise ValueError(f"Duplicate converted VAE key: {new_key}") |
| 114 | + converted_state_dict[new_key] = value |
| 115 | + |
| 116 | + return converted_state_dict |
| 117 | + |
| 118 | + |
| 119 | +def convert_qwen_image_vae(state_dict: dict[str, torch.Tensor]) -> AutoencoderKLQwenImage: |
| 120 | + converted_state_dict = convert_qwen_image_vae_state_dict(state_dict) |
| 121 | + with init_empty_weights(): |
| 122 | + vae = AutoencoderKLQwenImage() |
| 123 | + |
| 124 | + expected_keys = set(vae.state_dict().keys()) |
| 125 | + converted_keys = set(converted_state_dict.keys()) |
| 126 | + missing_keys = expected_keys - converted_keys |
| 127 | + unexpected_keys = converted_keys - expected_keys |
| 128 | + if missing_keys or unexpected_keys: |
| 129 | + if missing_keys: |
| 130 | + print(f"ERROR: missing VAE keys ({len(missing_keys)}):", file=sys.stderr) |
| 131 | + for key in sorted(missing_keys): |
| 132 | + print(key, file=sys.stderr) |
| 133 | + if unexpected_keys: |
| 134 | + print(f"ERROR: unexpected VAE keys ({len(unexpected_keys)}):", file=sys.stderr) |
| 135 | + for key in sorted(unexpected_keys): |
| 136 | + print(key, file=sys.stderr) |
| 137 | + sys.exit(1) |
| 138 | + |
| 139 | + vae.load_state_dict(converted_state_dict, strict=True, assign=True) |
| 140 | + return vae |
| 141 | + |
| 142 | + |
| 143 | +def infer_text_conditioner_config(state_dict: dict[str, torch.Tensor]) -> dict[str, Any]: |
| 144 | + model_dim = state_dict["blocks.0.self_attn.q_proj.weight"].shape[0] |
| 145 | + source_dim = state_dict["blocks.0.cross_attn.k_proj.weight"].shape[1] |
| 146 | + target_vocab_size, target_dim = state_dict["embed.weight"].shape |
| 147 | + attention_head_dim = state_dict["blocks.0.self_attn.q_norm.weight"].shape[0] |
| 148 | + num_layers = 1 + max(int(key.split(".")[1]) for key in state_dict if key.startswith("blocks.")) |
| 149 | + |
| 150 | + return { |
| 151 | + "source_dim": source_dim, |
| 152 | + "target_dim": target_dim, |
| 153 | + "model_dim": model_dim, |
| 154 | + "num_layers": num_layers, |
| 155 | + "num_attention_heads": model_dim // attention_head_dim, |
| 156 | + "target_vocab_size": target_vocab_size, |
| 157 | + } |
| 158 | + |
| 159 | + |
| 160 | +def convert_text_conditioner(state_dict: dict[str, torch.Tensor]) -> AnimaTextConditioner: |
| 161 | + config = infer_text_conditioner_config(state_dict) |
| 162 | + with init_empty_weights(): |
| 163 | + text_conditioner = AnimaTextConditioner(**config) |
| 164 | + |
| 165 | + expected_keys = set(text_conditioner.state_dict().keys()) |
| 166 | + converted_keys = set(state_dict.keys()) |
| 167 | + missing_keys = expected_keys - converted_keys |
| 168 | + unexpected_keys = converted_keys - expected_keys |
| 169 | + if missing_keys or unexpected_keys: |
| 170 | + if missing_keys: |
| 171 | + print(f"ERROR: missing text conditioner keys ({len(missing_keys)}):", file=sys.stderr) |
| 172 | + for key in sorted(missing_keys): |
| 173 | + print(key, file=sys.stderr) |
| 174 | + if unexpected_keys: |
| 175 | + print(f"ERROR: unexpected text conditioner keys ({len(unexpected_keys)}):", file=sys.stderr) |
| 176 | + for key in sorted(unexpected_keys): |
| 177 | + print(key, file=sys.stderr) |
| 178 | + sys.exit(1) |
| 179 | + |
| 180 | + text_conditioner.load_state_dict(state_dict, strict=True, assign=True) |
| 181 | + return text_conditioner |
| 182 | + |
| 183 | + |
| 184 | +def infer_qwen3_config(state_dict: dict[str, torch.Tensor]) -> Qwen3Config: |
| 185 | + vocab_size, hidden_size = state_dict["embed_tokens.weight"].shape |
| 186 | + intermediate_size = state_dict["layers.0.mlp.gate_proj.weight"].shape[0] |
| 187 | + num_hidden_layers = 1 + max(int(key.split(".")[1]) for key in state_dict if key.startswith("layers.")) |
| 188 | + head_dim = state_dict["layers.0.self_attn.q_norm.weight"].shape[0] |
| 189 | + num_attention_heads = state_dict["layers.0.self_attn.q_proj.weight"].shape[0] // head_dim |
| 190 | + num_key_value_heads = state_dict["layers.0.self_attn.k_proj.weight"].shape[0] // head_dim |
| 191 | + |
| 192 | + return Qwen3Config( |
| 193 | + vocab_size=vocab_size, |
| 194 | + hidden_size=hidden_size, |
| 195 | + intermediate_size=intermediate_size, |
| 196 | + num_hidden_layers=num_hidden_layers, |
| 197 | + num_attention_heads=num_attention_heads, |
| 198 | + num_key_value_heads=num_key_value_heads, |
| 199 | + max_position_embeddings=32768, |
| 200 | + rms_norm_eps=1e-6, |
| 201 | + rope_theta=1000000.0, |
| 202 | + head_dim=head_dim, |
| 203 | + attention_bias=False, |
| 204 | + tie_word_embeddings=False, |
| 205 | + ) |
| 206 | + |
| 207 | + |
| 208 | +def convert_text_encoder(state_dict: dict[str, torch.Tensor]) -> Qwen3Model: |
| 209 | + state_dict = {key.removeprefix("model."): value for key, value in state_dict.items()} |
| 210 | + config = infer_qwen3_config(state_dict) |
| 211 | + with init_empty_weights(): |
| 212 | + text_encoder = Qwen3Model(config) |
| 213 | + |
| 214 | + expected_keys = set(text_encoder.state_dict().keys()) |
| 215 | + converted_keys = set(state_dict.keys()) |
| 216 | + missing_keys = expected_keys - converted_keys |
| 217 | + unexpected_keys = converted_keys - expected_keys |
| 218 | + if missing_keys or unexpected_keys: |
| 219 | + if missing_keys: |
| 220 | + print(f"ERROR: missing Qwen3 keys ({len(missing_keys)}):", file=sys.stderr) |
| 221 | + for key in sorted(missing_keys): |
| 222 | + print(key, file=sys.stderr) |
| 223 | + if unexpected_keys: |
| 224 | + print(f"ERROR: unexpected Qwen3 keys ({len(unexpected_keys)}):", file=sys.stderr) |
| 225 | + for key in sorted(unexpected_keys): |
| 226 | + print(key, file=sys.stderr) |
| 227 | + sys.exit(1) |
| 228 | + |
| 229 | + text_encoder.load_state_dict(state_dict, strict=True, assign=True) |
| 230 | + return text_encoder |
| 231 | + |
| 232 | + |
| 233 | +def split_anima_transformer_checkpoint( |
| 234 | + state_dict: dict[str, torch.Tensor], |
| 235 | +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: |
| 236 | + transformer_state_dict = {} |
| 237 | + text_conditioner_state_dict = {} |
| 238 | + adapter_prefix = "net.llm_adapter." |
| 239 | + |
| 240 | + for key, value in state_dict.items(): |
| 241 | + if key.startswith(adapter_prefix): |
| 242 | + text_conditioner_state_dict[key.removeprefix(adapter_prefix)] = value |
| 243 | + else: |
| 244 | + transformer_state_dict[key] = value |
| 245 | + |
| 246 | + return transformer_state_dict, text_conditioner_state_dict |
| 247 | + |
| 248 | + |
| 249 | +def save_pipeline(args, transformer, text_conditioner, text_encoder, vae): |
| 250 | + tokenizer = AutoTokenizer.from_pretrained(args.qwen_tokenizer_path) |
| 251 | + t5_tokenizer = T5TokenizerFast.from_pretrained(args.t5_tokenizer_path) |
| 252 | + scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0) |
| 253 | + |
| 254 | + pipe = AnimaAutoBlocks().init_pipeline() |
| 255 | + pipe.update_components( |
| 256 | + text_encoder=text_encoder, |
| 257 | + tokenizer=tokenizer, |
| 258 | + t5_tokenizer=t5_tokenizer, |
| 259 | + text_conditioner=text_conditioner, |
| 260 | + transformer=transformer, |
| 261 | + vae=vae, |
| 262 | + scheduler=scheduler, |
| 263 | + ) |
| 264 | + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size=args.max_shard_size) |
| 265 | + |
| 266 | + |
| 267 | +def get_args(): |
| 268 | + parser = argparse.ArgumentParser() |
| 269 | + parser.add_argument("--transformer_ckpt_path", type=str, required=True, help="Path to Anima DiT safetensors") |
| 270 | + parser.add_argument("--text_encoder_ckpt_path", type=str, required=True, help="Path to Qwen3 text encoder") |
| 271 | + parser.add_argument("--vae_ckpt_path", type=str, required=True, help="Path to Qwen-Image VAE safetensors") |
| 272 | + parser.add_argument("--qwen_tokenizer_path", type=str, default=None) |
| 273 | + parser.add_argument("--t5_tokenizer_path", type=str, default=None) |
| 274 | + parser.add_argument("--output_path", type=str, required=True) |
| 275 | + parser.add_argument("--save_pipeline", action="store_true") |
| 276 | + parser.add_argument("--dtype", default="bf16", choices=list(DTYPE_MAPPING.keys())) |
| 277 | + parser.add_argument("--max_shard_size", default="5GB") |
| 278 | + return parser.parse_args() |
| 279 | + |
| 280 | + |
| 281 | +if __name__ == "__main__": |
| 282 | + args = get_args() |
| 283 | + output_path = pathlib.Path(args.output_path) |
| 284 | + dtype = DTYPE_MAPPING[args.dtype] |
| 285 | + |
| 286 | + raw_transformer_state_dict = load_file(args.transformer_ckpt_path, device="cpu") |
| 287 | + transformer_state_dict, text_conditioner_state_dict = split_anima_transformer_checkpoint( |
| 288 | + raw_transformer_state_dict |
| 289 | + ) |
| 290 | + transformer = convert_transformer( |
| 291 | + "Cosmos-2.0-Diffusion-2B-Text2Image", state_dict=transformer_state_dict, weights_only=True |
| 292 | + ).to(dtype=dtype) |
| 293 | + text_conditioner = convert_text_conditioner(text_conditioner_state_dict).to(dtype=dtype) |
| 294 | + |
| 295 | + text_encoder_state_dict = load_file(args.text_encoder_ckpt_path, device="cpu") |
| 296 | + text_encoder = convert_text_encoder(text_encoder_state_dict).to(dtype=dtype) |
| 297 | + |
| 298 | + vae_state_dict = load_file(args.vae_ckpt_path, device="cpu") |
| 299 | + vae = convert_qwen_image_vae(vae_state_dict).to(dtype=dtype) |
| 300 | + |
| 301 | + if args.save_pipeline: |
| 302 | + if args.qwen_tokenizer_path is None or args.t5_tokenizer_path is None: |
| 303 | + raise ValueError("`--qwen_tokenizer_path` and `--t5_tokenizer_path` are required with `--save_pipeline`.") |
| 304 | + save_pipeline(args, transformer, text_conditioner, text_encoder, vae) |
| 305 | + else: |
| 306 | + output_path.mkdir(parents=True, exist_ok=True) |
| 307 | + transformer.save_pretrained( |
| 308 | + output_path / "transformer", safe_serialization=True, max_shard_size=args.max_shard_size |
| 309 | + ) |
| 310 | + text_conditioner.save_pretrained( |
| 311 | + output_path / "text_conditioner", safe_serialization=True, max_shard_size=args.max_shard_size |
| 312 | + ) |
| 313 | + text_encoder.save_pretrained( |
| 314 | + output_path / "text_encoder", safe_serialization=True, max_shard_size=args.max_shard_size |
| 315 | + ) |
| 316 | + vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size=args.max_shard_size) |
0 commit comments