@@ -224,7 +224,7 @@ def get_dummy_inputs(self):
224224 """Override to provide inputs matching the tiny Wan Animate model dimensions."""
225225 return {
226226 "hidden_states" : randn_tensor (
227- (1 , 36 , 21 , 64 , 64 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
227+ (1 , 36 , 5 , 16 , 16 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
228228 ),
229229 "encoder_hidden_states" : randn_tensor (
230230 (1 , 512 , 4096 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
@@ -233,10 +233,10 @@ def get_dummy_inputs(self):
233233 (1 , 257 , 1280 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
234234 ),
235235 "pose_hidden_states" : randn_tensor (
236- (1 , 16 , 20 , 64 , 64 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
236+ (1 , 16 , 4 , 16 , 16 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
237237 ),
238238 "face_pixel_values" : randn_tensor (
239- (1 , 3 , 77 , 512 , 512 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
239+ (1 , 3 , 13 , 512 , 512 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
240240 ),
241241 "timestep" : torch .tensor ([1.0 ]).to (torch_device , self .torch_dtype ),
242242 }
0 commit comments