@@ -1989,9 +1989,27 @@ def concat_cond(self, **kwargs):
19891989
19901990 latent_dim = self .latent_format .latent_channels
19911991 image = utils .common_upscale (image .to (device ), noise .shape [- 1 ], noise .shape [- 2 ], "bilinear" , "center" )
1992+
1993+ if noise .ndim == 5 and image .ndim == 5 :
1994+ if image .shape [- 3 ] < noise .shape [- 3 ]:
1995+ image = torch .nn .functional .pad (image , (0 , 0 , 0 , 0 , 0 , noise .shape [- 3 ] - image .shape [- 3 ]), "constant" , 0 )
1996+ elif image .shape [- 3 ] > noise .shape [- 3 ]:
1997+ image = image [:, :, :noise .shape [- 3 ]]
1998+
19921999 for i in range (0 , image .shape [1 ], latent_dim ):
19932000 image [:, i :i + latent_dim ] = self .process_latent_in (image [:, i :i + latent_dim ])
19942001 image = utils .resize_to_batch_size (image , noise .shape [0 ])
2002+
2003+ if image .shape [1 ] > extra_channels :
2004+ image = image [:, :extra_channels ]
2005+ elif image .shape [1 ] < extra_channels :
2006+ repeats = extra_channels // image .shape [1 ]
2007+ remainder = extra_channels % image .shape [1 ]
2008+ parts = [image ] * repeats
2009+ if remainder > 0 :
2010+ parts .append (image [:, :remainder ])
2011+ image = torch .cat (parts , dim = 1 )
2012+
19952013 return image
19962014
19972015 def extra_conds (self , ** kwargs ):
0 commit comments