Skip to content

Commit b024454

Browse files
authored
Merge branch 'main' into flash-attn-mask
2 parents 8be05cd + d773308 commit b024454

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

tests/models/transformers/test_models_transformer_wan_animate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def get_dummy_inputs(self):
224224
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
225225
return {
226226
"hidden_states": randn_tensor(
227-
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
227+
(1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
228228
),
229229
"encoder_hidden_states": randn_tensor(
230230
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
@@ -233,10 +233,10 @@ def get_dummy_inputs(self):
233233
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
234234
),
235235
"pose_hidden_states": randn_tensor(
236-
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
236+
(1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
237237
),
238238
"face_pixel_values": randn_tensor(
239-
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
239+
(1, 3, 13, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
240240
),
241241
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
242242
}

0 commit comments

Comments
 (0)