diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 24ba5d507328..9b71c864e6f7 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -1749,8 +1749,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = latents_cache[step].mode() else: with offload_models(vae, device=accelerator.device, offload=args.offload): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) - model_input = vae.encode(pixel_values).latent_dist.mode() + pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.mode() model_input = Flux2Pipeline._patchify_latents(model_input) model_input = (model_input - latents_bn_mean) / latents_bn_std diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index d1396a09b074..f53a28bb34fa 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -1686,11 +1686,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = cond_latents_cache[step].mode() else: with offload_models(vae, device=accelerator.device, offload=args.offload): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) - cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype) - - model_input = vae.encode(pixel_values).latent_dist.mode() - cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() + pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype) + cond_pixel_values = batch["cond_pixel_values"].to(device=accelerator.device, dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.mode() + cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() # model_input = Flux2Pipeline._encode_vae_image(pixel_values) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 942c1317e3a8..2aa5a1c3e30c 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -1689,8 +1689,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = latents_cache[step].mode() else: with offload_models(vae, device=accelerator.device, offload=args.offload): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) - model_input = vae.encode(pixel_values).latent_dist.mode() + pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.mode() model_input = Flux2KleinPipeline._patchify_latents(model_input) model_input = (model_input - latents_bn_mean) / latents_bn_std diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index b19714d666e1..4c1838a0a4e1 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -1634,11 +1634,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = cond_latents_cache[step].mode() else: with offload_models(vae, device=accelerator.device, offload=args.offload): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) - cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype) - - model_input = vae.encode(pixel_values).latent_dist.mode() - cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() + pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype) + cond_pixel_values = batch["cond_pixel_values"].to(device=accelerator.device, dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.mode() + cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() model_input = Flux2KleinPipeline._patchify_latents(model_input) model_input = (model_input - latents_bn_mean) / latents_bn_std diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index 623ae4d2aca3..5f2c3b2f637e 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -1665,8 +1665,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = latents_cache[step].mode() else: with offload_models(vae, device=accelerator.device, offload=args.offload): - pixel_values = batch["pixel_values"].to(dtype=vae.dtype) - model_input = vae.encode(pixel_values).latent_dist.mode() + pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.mode() model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor # Sample noise that we'll add to the latents