Skip to content

Commit c3a4cd1

Browse files
authored
[CI] Refactor Wan Model Tests (#13082)
* update * update * update * update * update * update * update * update
1 parent 4d00980 commit c3a4cd1

File tree

11 files changed

+733
-133
lines changed

11 files changed

+733
-133
lines changed

src/diffusers/models/transformers/transformer_chronoedit.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
4343
encoder_hidden_states = hidden_states
4444

4545
if attn.fused_projections:
46-
if attn.cross_attention_dim_head is None:
46+
if not attn.is_cross_attention:
4747
# In self-attention layers, we can fuse the entire QKV projection into a single linear
4848
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
4949
else:
@@ -219,15 +219,18 @@ def __init__(
219219
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
220220
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
221221

222-
self.is_cross_attention = cross_attention_dim_head is not None
222+
if is_cross_attention is not None:
223+
self.is_cross_attention = is_cross_attention
224+
else:
225+
self.is_cross_attention = cross_attention_dim_head is not None
223226

224227
self.set_processor(processor)
225228

226229
def fuse_projections(self):
227230
if getattr(self, "fused_projections", False):
228231
return
229232

230-
if self.cross_attention_dim_head is None:
233+
if not self.is_cross_attention:
231234
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
232235
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
233236
out_features, in_features = concatenated_weights.shape

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
4242
encoder_hidden_states = hidden_states
4343

4444
if attn.fused_projections:
45-
if attn.cross_attention_dim_head is None:
45+
if not attn.is_cross_attention:
4646
# In self-attention layers, we can fuse the entire QKV projection into a single linear
4747
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
4848
else:
@@ -214,15 +214,18 @@ def __init__(
214214
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
215215
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
216216

217-
self.is_cross_attention = cross_attention_dim_head is not None
217+
if is_cross_attention is not None:
218+
self.is_cross_attention = is_cross_attention
219+
else:
220+
self.is_cross_attention = cross_attention_dim_head is not None
218221

219222
self.set_processor(processor)
220223

221224
def fuse_projections(self):
222225
if getattr(self, "fused_projections", False):
223226
return
224227

225-
if self.cross_attention_dim_head is None:
228+
if not self.is_cross_attention:
226229
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
227230
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
228231
out_features, in_features = concatenated_weights.shape

src/diffusers/models/transformers/transformer_wan_animate.py

Lines changed: 14 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
5454
encoder_hidden_states = hidden_states
5555

5656
if attn.fused_projections:
57-
if attn.cross_attention_dim_head is None:
57+
if not attn.is_cross_attention:
5858
# In self-attention layers, we can fuse the entire QKV projection into a single linear
5959
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
6060
else:
@@ -502,24 +502,27 @@ def __init__(
502502
dim_head: int = 64,
503503
eps: float = 1e-6,
504504
cross_attention_dim_head: Optional[int] = None,
505+
bias: bool = True,
505506
processor=None,
506507
):
507508
super().__init__()
508509
self.inner_dim = dim_head * heads
509510
self.heads = heads
510-
self.cross_attention_head_dim = cross_attention_dim_head
511+
self.cross_attention_dim_head = cross_attention_dim_head
511512
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
513+
self.use_bias = bias
514+
self.is_cross_attention = cross_attention_dim_head is not None
512515

513516
# 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
514517
# NOTE: this is not used in "vanilla" WanAttention
515518
self.pre_norm_q = nn.LayerNorm(dim, eps, elementwise_affine=False)
516519
self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)
517520

518521
# 2. QKV and Output Projections
519-
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
520-
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
521-
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
522-
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
522+
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=bias)
523+
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
524+
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
525+
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=bias)
523526

524527
# 3. QK Norm
525528
# NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
@@ -682,15 +685,18 @@ def __init__(
682685
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
683686
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
684687

685-
self.is_cross_attention = cross_attention_dim_head is not None
688+
if is_cross_attention is not None:
689+
self.is_cross_attention = is_cross_attention
690+
else:
691+
self.is_cross_attention = cross_attention_dim_head is not None
686692

687693
self.set_processor(processor)
688694

689695
def fuse_projections(self):
690696
if getattr(self, "fused_projections", False):
691697
return
692698

693-
if self.cross_attention_dim_head is None:
699+
if not self.is_cross_attention:
694700
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
695701
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
696702
out_features, in_features = concatenated_weights.shape

src/diffusers/models/transformers/transformer_wan_vace.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@ def __init__(
7676
eps=eps,
7777
added_kv_proj_dim=added_kv_proj_dim,
7878
processor=WanAttnProcessor(),
79+
is_cross_attention=True,
7980
)
8081
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
8182

@@ -178,6 +179,7 @@ class WanVACETransformer3DModel(
178179
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
179180
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
180181
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
182+
_repeated_blocks = ["WanTransformerBlock", "WanVACETransformerBlock"]
181183

182184
@register_to_config
183185
def __init__(

src/diffusers/quantizers/gguf/gguf_quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ def __init__(self, quantization_config, **kwargs):
4141

4242
self.compute_dtype = quantization_config.compute_dtype
4343
self.pre_quantized = quantization_config.pre_quantized
44-
self.modules_to_not_convert = quantization_config.modules_to_not_convert
44+
self.modules_to_not_convert = quantization_config.modules_to_not_convert or []
4545

4646
if not isinstance(self.modules_to_not_convert, list):
4747
self.modules_to_not_convert = [self.modules_to_not_convert]

tests/models/testing_utils/common.py

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -446,16 +446,17 @@ def test_getattr_is_correct(self, caplog):
446446
torch_device not in ["cuda", "xpu"],
447447
reason="float16 and bfloat16 can only be used with an accelerator",
448448
)
449-
def test_keep_in_fp32_modules(self):
449+
def test_keep_in_fp32_modules(self, tmp_path):
450450
model = self.model_class(**self.get_init_dict())
451451
fp32_modules = model._keep_in_fp32_modules
452452

453453
if fp32_modules is None or len(fp32_modules) == 0:
454454
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
455455

456-
# Test with float16
457-
model.to(torch_device)
458-
model.to(torch.float16)
456+
# Save the model and reload with float16 dtype
457+
# _keep_in_fp32_modules is only enforced during from_pretrained loading
458+
model.save_pretrained(tmp_path)
459+
model = self.model_class.from_pretrained(tmp_path, torch_dtype=torch.float16).to(torch_device)
459460

460461
for name, param in model.named_parameters():
461462
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
@@ -470,7 +471,7 @@ def test_keep_in_fp32_modules(self):
470471
)
471472
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
472473
@torch.no_grad()
473-
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
474+
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, rtol=0):
474475
model = self.model_class(**self.get_init_dict())
475476
model.to(torch_device)
476477
fp32_modules = model._keep_in_fp32_modules or []
@@ -490,10 +491,6 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
490491
output = model(**inputs, return_dict=False)[0]
491492
output_loaded = model_loaded(**inputs, return_dict=False)[0]
492493

493-
self._check_dtype_inference_output(output, output_loaded, dtype)
494-
495-
def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
496-
"""Check dtype inference output with configurable tolerance."""
497494
assert_tensors_close(
498495
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
499496
)

tests/models/testing_utils/quantization.py

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -176,15 +176,7 @@ def _test_quantization_inference(self, config_kwargs):
176176
model_quantized = self._create_quantized_model(config_kwargs)
177177
model_quantized.to(torch_device)
178178

179-
# Get model dtype from first parameter
180-
model_dtype = next(model_quantized.parameters()).dtype
181-
182179
inputs = self.get_dummy_inputs()
183-
# Cast inputs to model dtype
184-
inputs = {
185-
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
186-
for k, v in inputs.items()
187-
}
188180
output = model_quantized(**inputs, return_dict=False)[0]
189181

190182
assert output is not None, "Model output is None"
@@ -229,6 +221,8 @@ def _test_quantization_lora_inference(self, config_kwargs):
229221
init_lora_weights=False,
230222
)
231223
model.add_adapter(lora_config)
224+
# Move LoRA adapter weights to device (they default to CPU)
225+
model.to(torch_device)
232226

233227
inputs = self.get_dummy_inputs()
234228
output = model(**inputs, return_dict=False)[0]
@@ -1021,9 +1015,6 @@ def test_gguf_dequantize(self):
10211015
"""Test that dequantize() works correctly."""
10221016
self._test_dequantize({"compute_dtype": torch.bfloat16})
10231017

1024-
def test_gguf_quantized_layers(self):
1025-
self._test_quantized_layers({"compute_dtype": torch.bfloat16})
1026-
10271018

10281019
@is_quantization
10291020
@is_modelopt

0 commit comments

Comments
 (0)