Skip to content

Commit 53279ef

Browse files
samadwar, samedwardsFM, and github-actions[bot] authored
[From Single File] support from_single_file method for WanAnimateTransformer3DModel (#12691)
* Add `WanAnimateTransformer3DModel` to `SINGLE_FILE_LOADABLE_CLASSES` * Fixed dtype mismatch when loading a single file * Fixed a bug that results in white noise for generation * Update dtype check for time embedder - caused white noise output * Improve code readability * Optimize dtype handling Removed unnecessary dtype conversions for timestep and weight. * Apply style fixes * Refactor time embedding dtype handling Adjust time embedding type conversion for compatibility. * Apply style fixes * Modify comment for WanTimeTextImageEmbedding class --------- Co-authored-by: Sam Edwards <sam.edwards1976@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent d9959bd commit 53279ef

File tree

3 files changed

+91
-13
lines changed

3 files changed

+91
-13
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@
152152
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
153153
"default_subfolder": "transformer",
154154
},
155+
"WanAnimateTransformer3DModel": {
156+
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
157+
"default_subfolder": "transformer",
158+
},
155159
"AutoencoderKLWan": {
156160
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
157161
"default_subfolder": "vae",

src/diffusers/loaders/single_file_utils.py

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@
136136
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
137137
"wan_vae": "decoder.middle.0.residual.0.gamma",
138138
"wan_vace": "vace_blocks.0.after_proj.bias",
139+
"wan_animate": "motion_encoder.dec.direction.weight",
139140
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
140141
"cosmos-1.0": [
141142
"net.x_embedder.proj.1.weight",
@@ -219,6 +220,7 @@
219220
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
220221
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
221222
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
223+
"wan-animate-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.2-Animate-14B-Diffusers"},
222224
"wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
223225
"wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
224226
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
@@ -759,6 +761,9 @@ def infer_diffusers_model_type(checkpoint):
759761
elif checkpoint[target_key].shape[0] == 5120:
760762
model_type = "wan-vace-14B"
761763

764+
if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
765+
model_type = "wan-animate-14B"
766+
762767
elif checkpoint[target_key].shape[0] == 1536:
763768
model_type = "wan-t2v-1.3B"
764769
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
@@ -3154,13 +3159,64 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
31543159

31553160

31563161
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
3162+
def generate_motion_encoder_mappings():
3163+
mappings = {
3164+
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
3165+
"motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
3166+
"motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
3167+
"motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
3168+
"motion_encoder.enc.fc": "motion_encoder.motion_network",
3169+
}
3170+
3171+
for i in range(7):
3172+
conv_idx = i + 1
3173+
mappings.update(
3174+
{
3175+
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.0.weight": f"motion_encoder.res_blocks.{i}.conv1.weight",
3176+
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.1.bias": f"motion_encoder.res_blocks.{i}.conv1.act_fn.bias",
3177+
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.1.weight": f"motion_encoder.res_blocks.{i}.conv2.weight",
3178+
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.2.bias": f"motion_encoder.res_blocks.{i}.conv2.act_fn.bias",
3179+
f"motion_encoder.enc.net_app.convs.{conv_idx}.skip.1.weight": f"motion_encoder.res_blocks.{i}.conv_skip.weight",
3180+
}
3181+
)
3182+
3183+
return mappings
3184+
3185+
def generate_face_adapter_mappings():
3186+
return {
3187+
"face_adapter.fuser_blocks": "face_adapter",
3188+
".k_norm.": ".norm_k.",
3189+
".q_norm.": ".norm_q.",
3190+
".linear1_q.": ".to_q.",
3191+
".linear2.": ".to_out.",
3192+
"conv1_local.conv": "conv1_local",
3193+
"conv2.conv": "conv2",
3194+
"conv3.conv": "conv3",
3195+
}
3196+
3197+
def split_tensor_handler(key, state_dict, split_pattern, target_keys):
3198+
tensor = state_dict.pop(key)
3199+
split_idx = tensor.shape[0] // 2
3200+
3201+
new_key_1 = key.replace(split_pattern, target_keys[0])
3202+
new_key_2 = key.replace(split_pattern, target_keys[1])
3203+
3204+
state_dict[new_key_1] = tensor[:split_idx]
3205+
state_dict[new_key_2] = tensor[split_idx:]
3206+
3207+
def reshape_bias_handler(key, state_dict):
3208+
if "motion_encoder.enc.net_app.convs." in key and ".bias" in key:
3209+
state_dict[key] = state_dict[key][0, :, 0, 0]
3210+
31573211
converted_state_dict = {}
31583212

3213+
# Strip model.diffusion_model prefix
31593214
keys = list(checkpoint.keys())
31603215
for k in keys:
31613216
if "model.diffusion_model." in k:
31623217
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
31633218

3219+
# Base transformer mappings
31643220
TRANSFORMER_KEYS_RENAME_DICT = {
31653221
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
31663222
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
@@ -3182,28 +3238,43 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
31823238
"ffn.0": "ffn.net.0.proj",
31833239
"ffn.2": "ffn.net.2",
31843240
# Hack to swap the layer names
3185-
# The original model calls the norms in following order: norm1, norm3, norm2
3186-
# We convert it to: norm1, norm2, norm3
31873241
"norm2": "norm__placeholder",
31883242
"norm3": "norm2",
31893243
"norm__placeholder": "norm3",
3190-
# For the I2V model
3244+
# I2V model
31913245
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
31923246
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
31933247
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
31943248
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
3195-
# For the VACE model
3249+
# VACE model
31963250
"before_proj": "proj_in",
31973251
"after_proj": "proj_out",
31983252
}
31993253

3254+
SPECIAL_KEYS_HANDLERS = {}
3255+
if any("face_adapter" in k for k in checkpoint.keys()):
3256+
TRANSFORMER_KEYS_RENAME_DICT.update(generate_face_adapter_mappings())
3257+
SPECIAL_KEYS_HANDLERS[".linear1_kv."] = (split_tensor_handler, [".to_k.", ".to_v."])
3258+
3259+
if any("motion_encoder" in k for k in checkpoint.keys()):
3260+
TRANSFORMER_KEYS_RENAME_DICT.update(generate_motion_encoder_mappings())
3261+
32003262
for key in list(checkpoint.keys()):
3201-
new_key = key[:]
3263+
reshape_bias_handler(key, checkpoint)
3264+
3265+
for key in list(checkpoint.keys()):
3266+
new_key = key
32023267
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
32033268
new_key = new_key.replace(replace_key, rename_key)
3204-
32053269
converted_state_dict[new_key] = checkpoint.pop(key)
32063270

3271+
for key in list(converted_state_dict.keys()):
3272+
for pattern, (handler_fn, target_keys) in SPECIAL_KEYS_HANDLERS.items():
3273+
if pattern not in key:
3274+
continue
3275+
handler_fn(key, converted_state_dict, pattern, target_keys)
3276+
break
3277+
32073278
return converted_state_dict
32083279

32093280

src/diffusers/models/transformers/transformer_wan_animate.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,11 @@ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
166166
# NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
167167
# set to 1, which should be equivalent to a 2D convolution
168168
expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
169+
x = x.to(expanded_kernel.dtype)
169170
x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)
170171

171172
# Main Conv2D with scaling
173+
x = x.to(self.weight.dtype)
172174
x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
173175

174176
# Activation with fused bias, if using
@@ -338,8 +340,7 @@ def forward(self, face_image: torch.Tensor, channel_dim: int = 1) -> torch.Tenso
338340
weight = self.motion_synthesis_weight + 1e-8
339341
# Upcast the QR orthogonalization operation to FP32
340342
original_motion_dtype = motion_feat.dtype
341-
motion_feat = motion_feat.to(torch.float32)
342-
weight = weight.to(torch.float32)
343+
motion_feat = motion_feat.to(weight.dtype)
343344

344345
Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)
345346

@@ -769,7 +770,7 @@ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
769770
return hidden_states
770771

771772

772-
# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
773+
# Modified from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
773774
class WanTimeTextImageEmbedding(nn.Module):
774775
def __init__(
775776
self,
@@ -803,10 +804,12 @@ def forward(
803804
if timestep_seq_len is not None:
804805
timestep = timestep.unflatten(0, (-1, timestep_seq_len))
805806

806-
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
807-
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
808-
timestep = timestep.to(time_embedder_dtype)
809-
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
807+
if self.time_embedder.linear_1.weight.dtype.is_floating_point:
808+
time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
809+
else:
810+
time_embedder_dtype = encoder_hidden_states.dtype
811+
812+
temb = self.time_embedder(timestep.to(time_embedder_dtype)).type_as(encoder_hidden_states)
810813
timestep_proj = self.time_proj(self.act_fn(temb))
811814

812815
encoder_hidden_states = self.text_embedder(encoder_hidden_states)

0 commit comments

Comments (0)