@@ -236,21 +236,11 @@ def __init__(
236236 self .vae_scale_factor_spatial = 2 ** len (self .vae .temperal_downsample ) if getattr (self , "vae" , None ) else 8
237237 self .video_processor = VideoProcessor (vae_scale_factor = self .vae_scale_factor_spatial , resample = "bilinear" )
238238
239- assert getattr (self .vae .config , "latents_mean" , None ), "VAE configuration must define `latents_mean`."
240- assert getattr (self .vae .config , "latents_std" , None ), "VAE configuration must define `latents_std`."
241-
242239 latents_mean = torch .tensor (self .vae .config .latents_mean ).view (1 , self .vae .config .z_dim , 1 , 1 , 1 ).float ()
243240 latents_std = torch .tensor (self .vae .config .latents_std ).view (1 , self .vae .config .z_dim , 1 , 1 , 1 ).float ()
244241 self .latents_mean = latents_mean
245242 self .latents_std = 1.0 / latents_std
246243
247- def get_latent_shape_cthw (self , height : int , width : int , num_frames : int ):
248- C = self .vae .config .z_dim
249- T = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
250- H = height // self .vae_scale_factor_spatial
251- W = width // self .vae_scale_factor_spatial
252- return (C , T , H , W )
253-
254244 def create_condition_mask (self , latent_shape , device , dtype , num_cond_latent_frames ):
255245 bsz , C , T , H , W = latent_shape
256246 cond_indicator = torch .zeros (bsz , 1 , T , 1 , 1 , dtype = dtype , device = device )
@@ -438,9 +428,11 @@ def prepare_latents(
438428 )
439429
440430 B = batch_size
441- C , T , H , W = self .get_latent_shape_cthw (height , width , num_frames_out )
431+ C = num_channels_latents
432+ T = (num_frames_out - 1 ) // self .vae_scale_factor_temporal + 1
433+ H = height // self .vae_scale_factor_spatial
434+ W = width // self .vae_scale_factor_spatial
442435 shape = (B , C , T , H , W )
443- assert C == num_channels_latents , f"Expected number of channels to be { num_channels_latents } , but got { C } ."
444436
445437 if num_frames_in == 0 :
446438 if latents is None :
0 commit comments