@@ -587,10 +587,15 @@ def prepare_latents(
         latent_image_ids = latent_image_ids.to(device)
 
         image = image.to(device=device, dtype=dtype)
-        if image.shape[1] != self.latent_channels:
+        if image.shape[1] != self.latent_channels * 4:
             image_latents = self._encode_vae_image(image=image, generator=generator)
         else:
             image_latents = image
+        latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
+        latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
+            image_latents.device, image_latents.dtype
+        )
+        image_latents = (image_latents - latents_bn_mean) / latents_bn_std
 
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -600,8 +605,6 @@ def prepare_latents(
             raise ValueError(
                 f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
             )
-        else:
-            image_latents = torch.cat([image_latents], dim=0)
 
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -974,11 +977,13 @@ def __call__(
 
         # 2. Preprocess image
         multiple_of = self.vae_scale_factor * 2
-        if isinstance(image, torch.Tensor) and image.ndim == 4 and image.size(1) == self.latent_channels:
+        if isinstance(image, torch.Tensor) and image.ndim == 4 and image.size(1) == self.latent_channels * 4:
             init_image = image
             original_image = image
             crops_coords = None
             resize_mode = "default"
+            height = image.shape[2] * self.vae_scale_factor * 2
+            width = image.shape[3] * self.vae_scale_factor * 2
         elif image is not None:
             if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
                 image = torch.cat(image, dim=0)
@@ -1011,12 +1016,10 @@ def __call__(
                 image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode
             )
 
-        init_image = init_image.to(dtype=torch.float32)
-
         # 2.2 Preprocess reference image
         processed_image_reference = None
         if image_reference is not None and not (
-            isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels
+            isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels * 4
         ):
             if (
                 isinstance(image_reference, list)
@@ -1045,7 +1048,13 @@ def __call__(
                 image_reference_width,
                 resize_mode="crop",
             )
-            processed_image_reference = processed_image_reference.to(dtype=torch.float32)
+        else:
+            if image_reference is not None:
+                bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_reference.device, image_reference.dtype)
+                bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
+                    image_reference.device, image_reference.dtype
+                )
+                processed_image_reference = (image_reference - bn_mean) / bn_std
 
         # 3. Define call parameters
         if prompt is not None and isinstance(prompt, str):
0 commit comments