Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/diffusers/pipelines/flux2/pipeline_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
image_latents = self._patchify_latents(image_latents)

latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
image_latents.device, image_latents.dtype
)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std

return image_latents
Expand Down
4 changes: 3 additions & 1 deletion src/diffusers/pipelines/flux2/pipeline_flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
image_latents = self._patchify_latents(image_latents)

latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
image_latents.device, image_latents.dtype
)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std

return image_latents
Expand Down
4 changes: 3 additions & 1 deletion src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
image_latents = self._patchify_latents(image_latents)

latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
image_latents.device, image_latents.dtype
)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std

return image_latents
Expand Down
Loading