Skip to content

Commit 97ee35f

Browse files
committed
update
1 parent 80ad468 commit 97ee35f

4 files changed

Lines changed: 12 additions & 2 deletions

File tree

src/diffusers/models/transformers/transformer_wan_vace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def forward(
 331  331          )
 332  332          if i in self.config.vace_layers:
 333  333              control_hint, scale = control_hidden_states_list.pop()
 334       -            hidden_states = hidden_states + control_hint * scale
      334  +            hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale
 335  335      else:
 336  336          # Prepare VACE hints
 337  337          control_hidden_states_list = []
@@ -346,7 +346,7 @@ def forward(
 346  346          hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 347  347          if i in self.config.vace_layers:
 348  348              control_hint, scale = control_hidden_states_list.pop()
 349       -            hidden_states = hidden_states + control_hint * scale
      349  +            hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale
 350  350
 351  351      # 6. Output norm, projection & unpatchify
 352  352      shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

tests/models/transformers/test_models_transformer_wan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,13 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
  91   91              (batch_size, num_channels, num_frames, height, width),
  92   92              generator=self.generator,
  93   93              device=torch_device,
       94  +            dtype=self.torch_dtype,
  94   95          ),
  95   96          "encoder_hidden_states": randn_tensor(
  96   97              (batch_size, sequence_length, text_encoder_embedding_dim),
  97   98              generator=self.generator,
  98   99              device=torch_device,
      100  +            dtype=self.torch_dtype,
  99  101          ),
 100  102          "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
 101  103      }

tests/models/transformers/test_models_transformer_wan_animate.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,27 +113,32 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
 113  113              (batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
 114  114              generator=self.generator,
 115  115              device=torch_device,
      116  +            dtype=self.torch_dtype,
 116  117          ),
 117  118          "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
 118  119          "encoder_hidden_states": randn_tensor(
 119  120              (batch_size, sequence_length, text_encoder_embedding_dim),
 120  121              generator=self.generator,
 121  122              device=torch_device,
      123  +            dtype=self.torch_dtype,
 122  124          ),
 123  125          "encoder_hidden_states_image": randn_tensor(
 124  126              (batch_size, clip_seq_len, clip_dim),
 125  127              generator=self.generator,
 126  128              device=torch_device,
      129  +            dtype=self.torch_dtype,
 127  130          ),
 128  131          "pose_hidden_states": randn_tensor(
 129  132              (batch_size, num_channels, num_frames, height, width),
 130  133              generator=self.generator,
 131  134              device=torch_device,
      135  +            dtype=self.torch_dtype,
 132  136          ),
 133  137          "face_pixel_values": randn_tensor(
 134  138              (batch_size, 3, inference_segment_length, face_height, face_width),
 135  139              generator=self.generator,
 136  140              device=torch_device,
      141  +            dtype=self.torch_dtype,
 137  142          ),
 138  143      }
 139  144
tests/models/transformers/test_models_transformer_wan_vace.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,19 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
  96   96              (batch_size, num_channels, num_frames, height, width),
  97   97              generator=self.generator,
  98   98              device=torch_device,
       99  +            dtype=self.torch_dtype,
  99  100          ),
 100  101          "encoder_hidden_states": randn_tensor(
 101  102              (batch_size, sequence_length, text_encoder_embedding_dim),
 102  103              generator=self.generator,
 103  104              device=torch_device,
      105  +            dtype=self.torch_dtype,
 104  106          ),
 105  107          "control_hidden_states": randn_tensor(
 106  108              (batch_size, vace_in_channels, num_frames, height, width),
 107  109              generator=self.generator,
 108  110              device=torch_device,
      111  +            dtype=self.torch_dtype,
 109  112          ),
 110  113          "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
 111  114      }

0 commit comments

Comments
 (0)