Skip to content

Commit 41d8a98

Browse files
committed
Fixed pre-encoded latent preprocessing for source and reference images
1 parent 2d83f13 commit 41d8a98

2 files changed

Lines changed: 18 additions & 9 deletions

File tree

src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -587,10 +587,15 @@ def prepare_latents(
587587
latent_image_ids = latent_image_ids.to(device)
588588

589589
image = image.to(device=device, dtype=dtype)
590-
if image.shape[1] != self.latent_channels:
590+
if image.shape[1] != self.latent_channels * 4:
591591
image_latents = self._encode_vae_image(image=image, generator=generator)
592592
else:
593593
image_latents = image
594+
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
595+
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
596+
image_latents.device, image_latents.dtype
597+
)
598+
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
594599

595600
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
596601
# expand init_latents for batch_size
@@ -600,8 +605,6 @@ def prepare_latents(
600605
raise ValueError(
601606
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
602607
)
603-
else:
604-
image_latents = torch.cat([image_latents], dim=0)
605608

606609
if latents is None:
607610
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -974,11 +977,13 @@ def __call__(
974977

975978
# 2. Preprocess image
976979
multiple_of = self.vae_scale_factor * 2
977-
if isinstance(image, torch.Tensor) and image.ndim == 4 and image.size(1) == self.latent_channels:
980+
if isinstance(image, torch.Tensor) and image.ndim == 4 and image.size(1) == self.latent_channels * 4:
978981
init_image = image
979982
original_image = image
980983
crops_coords = None
981984
resize_mode = "default"
985+
height = image.shape[2] * self.vae_scale_factor * 2
986+
width = image.shape[3] * self.vae_scale_factor * 2
982987
elif image is not None:
983988
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
984989
image = torch.cat(image, dim=0)
@@ -1011,12 +1016,10 @@ def __call__(
10111016
image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode
10121017
)
10131018

1014-
init_image = init_image.to(dtype=torch.float32)
1015-
10161019
# 2.2 Preprocess reference image
10171020
processed_image_reference = None
10181021
if image_reference is not None and not (
1019-
isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels
1022+
isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels * 4
10201023
):
10211024
if (
10221025
isinstance(image_reference, list)
@@ -1045,7 +1048,13 @@ def __call__(
10451048
image_reference_width,
10461049
resize_mode="crop",
10471050
)
1048-
processed_image_reference = processed_image_reference.to(dtype=torch.float32)
1051+
else:
1052+
if image_reference is not None:
1053+
bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_reference.device, image_reference.dtype)
1054+
bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
1055+
image_reference.device, image_reference.dtype
1056+
)
1057+
processed_image_reference = (image_reference - bn_mean) / bn_std
10491058

10501059
# 3. Define call parameters
10511060
if prompt is not None and isinstance(prompt, str):

tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ class Flux2KleinInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
2828
params = frozenset(
2929
["prompt", "image", "image_reference", "mask_image", "height", "width", "guidance_scale", "prompt_embeds"]
3030
)
31-
batch_params = frozenset(["prompt", "image", "mask_image"])
31+
batch_params = frozenset(["prompt", "image", "image_reference", "mask_image"])
3232

3333
test_xformers_attention = False
3434
test_layerwise_casting = True

0 commit comments

Comments (0)