Skip to content

Commit b88e60b

Browse files
Fix: ensure consistent dtype and eval mode in pipeline save/load tests (#13339)
* Fix: ensure consistent dtype and eval mode in pipeline save/load tests
* Modify according to the comments
* Update according to the comments
* Update comment
* Code quality
* Cast buffers to torch.float16
* Resolve conflict
* Fix

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 7e463ea commit b88e60b

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

src/diffusers/models/transformers/transformer_wan_animate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,7 @@ class WanAnimateTransformer3DModel(
10291029
"norm2",
10301030
"norm3",
10311031
"motion_synthesis_weight",
1032+
"rope",
10321033
]
10331034
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
10341035
_repeated_blocks = ["WanTransformerBlock"]

tests/pipelines/test_pipelines_common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,10 +1443,24 @@ def test_save_load_float16(self, expected_max_diff=1e-2):
14431443
param.data = param.data.to(torch_device).to(torch.float32)
14441444
else:
14451445
param.data = param.data.to(torch_device).to(torch.float16)
1446+
for name, buf in module.named_buffers():
1447+
if not buf.is_floating_point():
1448+
buf.data = buf.data.to(torch_device)
1449+
elif any(
1450+
module_to_keep_in_fp32 in name.split(".")
1451+
for module_to_keep_in_fp32 in module._keep_in_fp32_modules
1452+
):
1453+
buf.data = buf.data.to(torch_device).to(torch.float32)
1454+
else:
1455+
buf.data = buf.data.to(torch_device).to(torch.float16)
14461456

14471457
elif hasattr(module, "half"):
14481458
components[name] = module.to(torch_device).half()
14491459

1460+
for key, component in components.items():
1461+
if hasattr(component, "eval"):
1462+
component.eval()
1463+
14501464
pipe = self.pipeline_class(**components)
14511465
for component in pipe.components.values():
14521466
if hasattr(component, "set_default_attn_processor"):

0 commit comments

Comments (0)