136136 "wan" : ["model.diffusion_model.head.modulation" , "head.modulation" ],
137137 "wan_vae" : "decoder.middle.0.residual.0.gamma" ,
138138 "wan_vace" : "vace_blocks.0.after_proj.bias" ,
139+ "wan_animate" : "motion_encoder.dec.direction.weight" ,
139140 "hidream" : "double_stream_blocks.0.block.adaLN_modulation.1.bias" ,
140141 "cosmos-1.0" : [
141142 "net.x_embedder.proj.1.weight" ,
219220 "wan-t2v-1.3B" : {"pretrained_model_name_or_path" : "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" },
220221 "wan-t2v-14B" : {"pretrained_model_name_or_path" : "Wan-AI/Wan2.1-T2V-14B-Diffusers" },
221222 "wan-i2v-14B" : {"pretrained_model_name_or_path" : "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" },
223+ "wan-animate-14B" : {"pretrained_model_name_or_path" : "Wan-AI/Wan2.2-Animate-14B-Diffusers" },
222224 "wan-vace-1.3B" : {"pretrained_model_name_or_path" : "Wan-AI/Wan2.1-VACE-1.3B-diffusers" },
223225 "wan-vace-14B" : {"pretrained_model_name_or_path" : "Wan-AI/Wan2.1-VACE-14B-diffusers" },
224226 "hidream" : {"pretrained_model_name_or_path" : "HiDream-ai/HiDream-I1-Dev" },
@@ -759,6 +761,9 @@ def infer_diffusers_model_type(checkpoint):
759761 elif checkpoint [target_key ].shape [0 ] == 5120 :
760762 model_type = "wan-vace-14B"
761763
764+ if CHECKPOINT_KEY_NAMES ["wan_animate" ] in checkpoint :
765+ model_type = "wan-animate-14B"
766+
762767 elif checkpoint [target_key ].shape [0 ] == 1536 :
763768 model_type = "wan-t2v-1.3B"
764769 elif checkpoint [target_key ].shape [0 ] == 5120 and checkpoint [target_key ].shape [1 ] == 16 :
@@ -3154,13 +3159,64 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
31543159
31553160
31563161def convert_wan_transformer_to_diffusers (checkpoint , ** kwargs ):
def generate_motion_encoder_mappings():
    """Build the key-rename table for the motion encoder weights.

    Each entry maps a substring of an original checkpoint key to the
    substring that replaces it in the converted key. The table holds five
    fixed entries followed by five entries for each of the seven residual
    blocks stored at ``convs.1`` through ``convs.7`` in the original layout,
    which become ``res_blocks.0`` through ``res_blocks.6``.

    Returns:
        dict: ``{original_substring: converted_substring}``. Insertion order
        is the fixed entries first, then the per-block entries in ascending
        block order.
    """
    mappings = {
        "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
        "motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
        "motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
        "motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
        "motion_encoder.enc.fc": "motion_encoder.motion_network",
    }

    # Per-block parameter suffixes: original layout -> converted layout.
    block_suffixes = {
        "conv1.0.weight": "conv1.weight",
        "conv1.1.bias": "conv1.act_fn.bias",
        "conv2.1.weight": "conv2.weight",
        "conv2.2.bias": "conv2.act_fn.bias",
        "skip.1.weight": "conv_skip.weight",
    }

    for block_idx in range(7):
        # The original checkpoint stores residual block ``i`` at ``convs.{i + 1}``;
        # ``convs.0`` is handled by the fixed entries above.
        conv_idx = block_idx + 1
        for old_suffix, new_suffix in block_suffixes.items():
            old_key = f"motion_encoder.enc.net_app.convs.{conv_idx}.{old_suffix}"
            mappings[old_key] = f"motion_encoder.res_blocks.{block_idx}.{new_suffix}"

    return mappings
3184+
def generate_face_adapter_mappings():
    """Return the key-rename table for the face adapter weights.

    Each entry maps a substring of an original checkpoint key to the
    substring that replaces it in the converted key.

    Returns:
        dict: ``{original_substring: converted_substring}`` in a fixed
        insertion order.
    """
    renames = {}

    # Flatten the ``fuser_blocks`` container level.
    renames["face_adapter.fuser_blocks"] = "face_adapter"

    # Normalisation and projection layers.
    renames[".k_norm."] = ".norm_k."
    renames[".q_norm."] = ".norm_q."
    renames[".linear1_q."] = ".to_q."
    renames[".linear2."] = ".to_out."

    # Drop the inner ``.conv`` level of the wrapped convolutions.
    renames["conv1_local.conv"] = "conv1_local"
    renames["conv2.conv"] = "conv2"
    renames["conv3.conv"] = "conv3"

    return renames
3196+
def split_tensor_handler(key, state_dict, split_pattern, target_keys):
    """Split one fused tensor into two halves stored under two new keys.

    The entry at ``key`` is removed from ``state_dict`` and its tensor is cut
    in half along the first dimension. The first half is stored under ``key``
    with ``split_pattern`` replaced by ``target_keys[0]``, the second half
    under ``key`` with ``split_pattern`` replaced by ``target_keys[1]``.

    Args:
        key: Key of the fused tensor in ``state_dict``.
        state_dict: Mapping of parameter names to tensors; modified in place.
        split_pattern: Substring of ``key`` to be replaced.
        target_keys: Two replacement substrings, for the first and second
            half respectively.

    Raises:
        KeyError: If ``key`` is not in ``state_dict``.
        ValueError: If the first dimension of the tensor is odd, so it
            cannot be split into two equal halves. ``state_dict`` is left
            unmodified in that case.
    """
    # Validate before popping so a failure never leaves the dict half-converted.
    fused = state_dict[key]
    rows = fused.shape[0]
    if rows % 2 != 0:
        raise ValueError(
            f"Cannot split `{key}` into two equal halves: first dimension is {rows}."
        )
    half = rows // 2

    first_key = key.replace(split_pattern, target_keys[0])
    second_key = key.replace(split_pattern, target_keys[1])

    del state_dict[key]
    state_dict[first_key] = fused[:half]
    state_dict[second_key] = fused[half:]
3206+
def reshape_bias_handler(key, state_dict):
    """Flatten a motion-encoder conv bias to one dimension, in place.

    Only keys that contain both ``"motion_encoder.enc.net_app.convs."`` and
    ``".bias"`` are touched; every other key is left as is. For a matching
    key the stored tensor is replaced by ``tensor[0, :, 0, 0]``.

    Args:
        key: Parameter name to inspect.
        state_dict: Mapping of parameter names to tensors; modified in place.
    """
    is_encoder_conv = "motion_encoder.enc.net_app.convs." in key
    is_bias = ".bias" in key
    if not (is_encoder_conv and is_bias):
        return

    # NOTE(review): the indexing assumes a 4-D bias, presumably shaped
    # (1, channels, 1, 1) -- confirm against the original checkpoint.
    state_dict[key] = state_dict[key][0, :, 0, 0]
3210+
31573211 converted_state_dict = {}
31583212
3213+ # Strip model.diffusion_model prefix
31593214 keys = list (checkpoint .keys ())
31603215 for k in keys :
31613216 if "model.diffusion_model." in k :
31623217 checkpoint [k .replace ("model.diffusion_model." , "" )] = checkpoint .pop (k )
31633218
3219+ # Base transformer mappings
31643220 TRANSFORMER_KEYS_RENAME_DICT = {
31653221 "time_embedding.0" : "condition_embedder.time_embedder.linear_1" ,
31663222 "time_embedding.2" : "condition_embedder.time_embedder.linear_2" ,
@@ -3182,28 +3238,43 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
31823238 "ffn.0" : "ffn.net.0.proj" ,
31833239 "ffn.2" : "ffn.net.2" ,
31843240 # Hack to swap the layer names
3185- # The original model calls the norms in following order: norm1, norm3, norm2
3186- # We convert it to: norm1, norm2, norm3
31873241 "norm2" : "norm__placeholder" ,
31883242 "norm3" : "norm2" ,
31893243 "norm__placeholder" : "norm3" ,
3190- # For the I2V model
3244+ # I2V model
31913245 "img_emb.proj.0" : "condition_embedder.image_embedder.norm1" ,
31923246 "img_emb.proj.1" : "condition_embedder.image_embedder.ff.net.0.proj" ,
31933247 "img_emb.proj.3" : "condition_embedder.image_embedder.ff.net.2" ,
31943248 "img_emb.proj.4" : "condition_embedder.image_embedder.norm2" ,
3195- # For the VACE model
3249+ # VACE model
31963250 "before_proj" : "proj_in" ,
31973251 "after_proj" : "proj_out" ,
31983252 }
31993253
3254+ SPECIAL_KEYS_HANDLERS = {}
3255+ if any ("face_adapter" in k for k in checkpoint .keys ()):
3256+ TRANSFORMER_KEYS_RENAME_DICT .update (generate_face_adapter_mappings ())
3257+ SPECIAL_KEYS_HANDLERS [".linear1_kv." ] = (split_tensor_handler , [".to_k." , ".to_v." ])
3258+
3259+ if any ("motion_encoder" in k for k in checkpoint .keys ()):
3260+ TRANSFORMER_KEYS_RENAME_DICT .update (generate_motion_encoder_mappings ())
3261+
32003262 for key in list (checkpoint .keys ()):
3201- new_key = key [:]
3263+ reshape_bias_handler (key , checkpoint )
3264+
3265+ for key in list (checkpoint .keys ()):
3266+ new_key = key
32023267 for replace_key , rename_key in TRANSFORMER_KEYS_RENAME_DICT .items ():
32033268 new_key = new_key .replace (replace_key , rename_key )
3204-
32053269 converted_state_dict [new_key ] = checkpoint .pop (key )
32063270
3271+ for key in list (converted_state_dict .keys ()):
3272+ for pattern , (handler_fn , target_keys ) in SPECIAL_KEYS_HANDLERS .items ():
3273+ if pattern not in key :
3274+ continue
3275+ handler_fn (key , converted_state_dict , pattern , target_keys )
3276+ break
3277+
32073278 return converted_state_dict
32083279
32093280
0 commit comments