Skip to content

Commit e5f4b23

Browse files
committed
Reduce WanAnimate TorchAO test input sizes to prevent OOM
Shrink dummy inputs to avoid OOM on devices without FlashAttention. Reduce hidden_states spatial from 64x64 to 16x16 and frames from 21 to 5, bringing self-attention sequence length from 21,504 to 320.
1 parent c8c8401 commit e5f4b23

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

tests/models/transformers/test_models_transformer_wan_animate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def get_dummy_inputs(self):
219219
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
220220
return {
221221
"hidden_states": randn_tensor(
222-
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
222+
(1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
223223
),
224224
"encoder_hidden_states": randn_tensor(
225225
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
@@ -228,10 +228,10 @@ def get_dummy_inputs(self):
228228
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
229229
),
230230
"pose_hidden_states": randn_tensor(
231-
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
231+
(1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
232232
),
233233
"face_pixel_values": randn_tensor(
234-
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
234+
(1, 3, 13, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
235235
),
236236
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
237237
}

0 commit comments

Comments (0)