Skip to content

Commit d31061b

Browse files
Fix VAE offload encode device mismatch in DreamBooth scripts (#13417)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent ee3c352 commit d31061b

File tree

5 files changed

+14
-16
lines changed

5 files changed

+14
-16
lines changed

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,8 +1749,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17491749
model_input = latents_cache[step].mode()
17501750
else:
17511751
with offload_models(vae, device=accelerator.device, offload=args.offload):
1752-
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1753-
model_input = vae.encode(pixel_values).latent_dist.mode()
1752+
pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
1753+
model_input = vae.encode(pixel_values).latent_dist.mode()
17541754

17551755
model_input = Flux2Pipeline._patchify_latents(model_input)
17561756
model_input = (model_input - latents_bn_mean) / latents_bn_std

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,11 +1686,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16861686
cond_model_input = cond_latents_cache[step].mode()
16871687
else:
16881688
with offload_models(vae, device=accelerator.device, offload=args.offload):
1689-
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1690-
cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
1691-
1692-
model_input = vae.encode(pixel_values).latent_dist.mode()
1693-
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
1689+
pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
1690+
cond_pixel_values = batch["cond_pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
1691+
model_input = vae.encode(pixel_values).latent_dist.mode()
1692+
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
16941693

16951694
# model_input = Flux2Pipeline._encode_vae_image(pixel_values)
16961695

examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,8 +1689,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16891689
model_input = latents_cache[step].mode()
16901690
else:
16911691
with offload_models(vae, device=accelerator.device, offload=args.offload):
1692-
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1693-
model_input = vae.encode(pixel_values).latent_dist.mode()
1692+
pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
1693+
model_input = vae.encode(pixel_values).latent_dist.mode()
16941694

16951695
model_input = Flux2KleinPipeline._patchify_latents(model_input)
16961696
model_input = (model_input - latents_bn_mean) / latents_bn_std

examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1634,11 +1634,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16341634
cond_model_input = cond_latents_cache[step].mode()
16351635
else:
16361636
with offload_models(vae, device=accelerator.device, offload=args.offload):
1637-
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1638-
cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
1639-
1640-
model_input = vae.encode(pixel_values).latent_dist.mode()
1641-
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
1637+
pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
1638+
cond_pixel_values = batch["cond_pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
1639+
model_input = vae.encode(pixel_values).latent_dist.mode()
1640+
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
16421641

16431642
model_input = Flux2KleinPipeline._patchify_latents(model_input)
16441643
model_input = (model_input - latents_bn_mean) / latents_bn_std

examples/dreambooth/train_dreambooth_lora_z_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,8 +1665,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16651665
model_input = latents_cache[step].mode()
16661666
else:
16671667
with offload_models(vae, device=accelerator.device, offload=args.offload):
1668-
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1669-
model_input = vae.encode(pixel_values).latent_dist.mode()
1668+
pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
1669+
model_input = vae.encode(pixel_values).latent_dist.mode()
16701670

16711671
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
16721672
# Sample noise that we'll add to the latents

0 commit comments

Comments
 (0)