Skip to content

Commit 0399bd3

Browse files
author
Talmaj Marinc
committed
Revert dtype to float32 to increase quality of video output.
1 parent b59b8d5 commit 0399bd3

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

comfy/ldm/cogvideo/model.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,14 @@ def _get_positional_embeddings(self, sample_height, sample_width, sample_frames,
157157
return joint_pos_embedding
158158

159159
def forward(self, text_embeds, image_embeds):
160-
text_embeds = self.text_proj(text_embeds)
160+
input_dtype = text_embeds.dtype
161+
text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype)
161162
batch_size, num_frames, channels, height, width = image_embeds.shape
162163

164+
proj_dtype = self.proj.weight.dtype
163165
if self.patch_size_t is None:
164166
image_embeds = image_embeds.reshape(-1, channels, height, width)
165-
image_embeds = self.proj(image_embeds)
167+
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
166168
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
167169
image_embeds = image_embeds.flatten(3).transpose(2, 3)
168170
image_embeds = image_embeds.flatten(1, 2)
@@ -174,7 +176,7 @@ def forward(self, text_embeds, image_embeds):
174176
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
175177
)
176178
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
177-
image_embeds = self.proj(image_embeds)
179+
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
178180

179181
embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
180182

@@ -378,7 +380,7 @@ def __init__(self,
378380
temporal_interpolation_scale=temporal_interpolation_scale,
379381
use_positional_embeddings=not use_rotary_positional_embeddings,
380382
use_learned_positional_embeddings=use_learned_positional_embeddings,
381-
device=device, dtype=dtype, operations=operations,
383+
device=device, dtype=torch.float32, operations=operations,
382384
)
383385

384386
# 2. Time embedding

0 commit comments

Comments
 (0)