@@ -113,27 +113,32 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
113113 (batch_size , 2 * num_channels + 4 , num_frames + 1 , height , width ),
114114 generator = self .generator ,
115115 device = torch_device ,
116+ dtype = self .torch_dtype ,
116117 ),
117118 "timestep" : torch .randint (0 , 1000 , size = (batch_size ,), generator = self .generator ).to (torch_device ),
118119 "encoder_hidden_states" : randn_tensor (
119120 (batch_size , sequence_length , text_encoder_embedding_dim ),
120121 generator = self .generator ,
121122 device = torch_device ,
123+ dtype = self .torch_dtype ,
122124 ),
123125 "encoder_hidden_states_image" : randn_tensor (
124126 (batch_size , clip_seq_len , clip_dim ),
125127 generator = self .generator ,
126128 device = torch_device ,
129+ dtype = self .torch_dtype ,
127130 ),
128131 "pose_hidden_states" : randn_tensor (
129132 (batch_size , num_channels , num_frames , height , width ),
130133 generator = self .generator ,
131134 device = torch_device ,
135+ dtype = self .torch_dtype ,
132136 ),
133137 "face_pixel_values" : randn_tensor (
134138 (batch_size , 3 , inference_segment_length , face_height , face_width ),
135139 generator = self .generator ,
136140 device = torch_device ,
141+ dtype = self .torch_dtype ,
137142 ),
138143 }
139144
0 commit comments