@@ -195,7 +195,7 @@ def get_dummy_inputs(self):
195195 """Override to provide inputs matching the tiny Wan Animate model dimensions."""
196196 return {
197197 "hidden_states" : randn_tensor (
198- (1 , 36 , 21 , 64 , 64 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
198+ (1 , 36 , 5 , 16 , 16 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
199199 ),
200200 "encoder_hidden_states" : randn_tensor (
201201 (1 , 512 , 4096 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
@@ -204,10 +204,10 @@ def get_dummy_inputs(self):
204204 (1 , 257 , 1280 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
205205 ),
206206 "pose_hidden_states" : randn_tensor (
207- (1 , 16 , 20 , 64 , 64 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
207+ (1 , 16 , 4 , 16 , 16 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
208208 ),
209209 "face_pixel_values" : randn_tensor (
210- (1 , 3 , 77 , 512 , 512 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
210+ (1 , 3 , 13 , 512 , 512 ), generator = self .generator , device = torch_device , dtype = self .torch_dtype
211211 ),
212212 "timestep" : torch .tensor ([1.0 ]).to (torch_device , self .torch_dtype ),
213213 }
0 commit comments