Skip to content

Commit 42a46e4

Browse files
adi776borateyiyixuxuDN6
authored
Fix missing latents_bn_std dtype cast in VAE normalization (#13299)
* Corrected casting of latents_bn_std * Propagated the fix to the klein inpaint pipeline --------- Co-authored-by: YiYi Xu <yixu310@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 1a8a17b commit 42a46e4

4 files changed

Lines changed: 12 additions & 4 deletions

File tree

src/diffusers/pipelines/flux2/pipeline_flux2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
611611
image_latents = self._patchify_latents(image_latents)
612612

613613
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
614-
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
614+
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
615+
image_latents.device, image_latents.dtype
616+
)
615617
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
616618

617619
return image_latents

src/diffusers/pipelines/flux2/pipeline_flux2_klein.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
467467
image_latents = self._patchify_latents(image_latents)
468468

469469
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
470-
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
470+
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
471+
image_latents.device, image_latents.dtype
472+
)
471473
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
472474

473475
return image_latents

src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
547547
image_latents = self._patchify_latents(image_latents)
548548

549549
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
550-
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
550+
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
551+
image_latents.device, image_latents.dtype
552+
)
551553
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
552554

553555
return image_latents

src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
477477
image_latents = self._patchify_latents(image_latents)
478478

479479
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
480-
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
480+
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
481+
image_latents.device, image_latents.dtype
482+
)
481483
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
482484

483485
return image_latents

0 commit comments

Comments
 (0)