Skip to content

Commit a362223

Browse files
Fix OOM in WanAnimate BitsAndBytes Training Test (#13777)
reduce input size for tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent f502538 commit a362223

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

tests/models/transformers/test_models_transformer_wan_animate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def get_dummy_inputs(self):
195195
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
196196
return {
197197
"hidden_states": randn_tensor(
198-
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
198+
(1, 36, 5, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
199199
),
200200
"encoder_hidden_states": randn_tensor(
201201
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
@@ -204,10 +204,10 @@ def get_dummy_inputs(self):
204204
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
205205
),
206206
"pose_hidden_states": randn_tensor(
207-
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
207+
(1, 16, 4, 16, 16), generator=self.generator, device=torch_device, dtype=self.torch_dtype
208208
),
209209
"face_pixel_values": randn_tensor(
210-
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
210+
(1, 3, 13, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
211211
),
212212
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
213213
}

0 commit comments

Comments
 (0)