@@ -912,6 +912,12 @@ def prepare_latents(
912912 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is { type (image )} "
913913 )
914914
915+ latents_mean = latents_std = None
916+ if hasattr (self .vae .config , "latents_mean" ) and self .vae .config .latents_mean is not None :
917+ latents_mean = torch .tensor (self .vae .config .latents_mean ).view (1 , 4 , 1 , 1 )
918+ if hasattr (self .vae .config , "latents_std" ) and self .vae .config .latents_std is not None :
919+ latents_std = torch .tensor (self .vae .config .latents_std ).view (1 , 4 , 1 , 1 )
920+
915921 # Offload text encoder if `enable_model_cpu_offload` was enabled
916922 if hasattr (self , "final_offload_hook" ) and self .final_offload_hook is not None :
917923 self .text_encoder_2 .to ("cpu" )
@@ -925,11 +931,6 @@ def prepare_latents(
925931 init_latents = image
926932
927933 else :
928- latents_mean = latents_std = None
929- if hasattr (self .vae .config , "latents_mean" ) and self .vae .config .latents_mean is not None :
930- latents_mean = torch .tensor (self .vae .config .latents_mean ).view (1 , 4 , 1 , 1 )
931- if hasattr (self .vae .config , "latents_std" ) and self .vae .config .latents_std is not None :
932- latents_std = torch .tensor (self .vae .config .latents_std ).view (1 , 4 , 1 , 1 )
933934 # make sure the VAE is in float32 mode, as it overflows in float16
934935 if self .vae .config .force_upcast :
935936 image = image .float ()
0 commit comments