@@ -157,12 +157,14 @@ def _get_positional_embeddings(self, sample_height, sample_width, sample_frames,
157157 return joint_pos_embedding
158158
159159 def forward (self , text_embeds , image_embeds ):
160- text_embeds = self .text_proj (text_embeds )
160+ input_dtype = text_embeds .dtype
161+ text_embeds = self .text_proj (text_embeds .to (self .text_proj .weight .dtype )).to (input_dtype )
161162 batch_size , num_frames , channels , height , width = image_embeds .shape
162163
164+ proj_dtype = self .proj .weight .dtype
163165 if self .patch_size_t is None :
164166 image_embeds = image_embeds .reshape (- 1 , channels , height , width )
165- image_embeds = self .proj (image_embeds )
167+ image_embeds = self .proj (image_embeds . to ( proj_dtype )). to ( input_dtype )
166168 image_embeds = image_embeds .view (batch_size , num_frames , * image_embeds .shape [1 :])
167169 image_embeds = image_embeds .flatten (3 ).transpose (2 , 3 )
168170 image_embeds = image_embeds .flatten (1 , 2 )
@@ -174,7 +176,7 @@ def forward(self, text_embeds, image_embeds):
174176 batch_size , num_frames // p_t , p_t , height // p , p , width // p , p , channels
175177 )
176178 image_embeds = image_embeds .permute (0 , 1 , 3 , 5 , 7 , 2 , 4 , 6 ).flatten (4 , 7 ).flatten (1 , 3 )
177- image_embeds = self .proj (image_embeds )
179+ image_embeds = self .proj (image_embeds . to ( proj_dtype )). to ( input_dtype )
178180
179181 embeds = torch .cat ([text_embeds , image_embeds ], dim = 1 ).contiguous ()
180182
@@ -378,7 +380,7 @@ def __init__(self,
378380 temporal_interpolation_scale = temporal_interpolation_scale ,
379381 use_positional_embeddings = not use_rotary_positional_embeddings ,
380382 use_learned_positional_embeddings = use_learned_positional_embeddings ,
381- device = device , dtype = dtype , operations = operations ,
383+ device = device , dtype = torch . float32 , operations = operations ,
382384 )
383385
384386 # 2. Time embedding
0 commit comments