Skip to content

Commit 9ac1219

Browse files
author
Talmaj Marinc
committed
Fix CogVideoX concat_cond to handle temporal dimension and normalize channel count
1 parent d2489b5 commit 9ac1219

1 file changed

Lines changed: 18 additions & 0 deletions

File tree

comfy/model_base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,9 +1989,27 @@ def concat_cond(self, **kwargs):
19891989

19901990
latent_dim = self.latent_format.latent_channels
19911991
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
1992+
1993+
if noise.ndim == 5 and image.ndim == 5:
1994+
if image.shape[-3] < noise.shape[-3]:
1995+
image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0)
1996+
elif image.shape[-3] > noise.shape[-3]:
1997+
image = image[:, :, :noise.shape[-3]]
1998+
19921999
for i in range(0, image.shape[1], latent_dim):
19932000
image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim])
19942001
image = utils.resize_to_batch_size(image, noise.shape[0])
2002+
2003+
if image.shape[1] > extra_channels:
2004+
image = image[:, :extra_channels]
2005+
elif image.shape[1] < extra_channels:
2006+
repeats = extra_channels // image.shape[1]
2007+
remainder = extra_channels % image.shape[1]
2008+
parts = [image] * repeats
2009+
if remainder > 0:
2010+
parts.append(image[:, :remainder])
2011+
image = torch.cat(parts, dim=1)
2012+
19952013
return image
19962014

19972015
def extra_conds(self, **kwargs):

0 commit comments

Comments
 (0)