Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,8 +1749,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
model_input = latents_cache[step].mode()
else:
with offload_models(vae, device=accelerator.device, offload=args.offload):
- pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
- model_input = vae.encode(pixel_values).latent_dist.mode()
+ pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.mode()

model_input = Flux2Pipeline._patchify_latents(model_input)
model_input = (model_input - latents_bn_mean) / latents_bn_std
Expand Down
9 changes: 4 additions & 5 deletions examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,11 +1686,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
cond_model_input = cond_latents_cache[step].mode()
else:
with offload_models(vae, device=accelerator.device, offload=args.offload):
- pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
- cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
-
- model_input = vae.encode(pixel_values).latent_dist.mode()
- cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
+ pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
+ cond_pixel_values = batch["cond_pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.mode()
+ cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()

# model_input = Flux2Pipeline._encode_vae_image(pixel_values)

Expand Down
4 changes: 2 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,8 +1689,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
model_input = latents_cache[step].mode()
else:
with offload_models(vae, device=accelerator.device, offload=args.offload):
- pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
- model_input = vae.encode(pixel_values).latent_dist.mode()
+ pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.mode()

model_input = Flux2KleinPipeline._patchify_latents(model_input)
model_input = (model_input - latents_bn_mean) / latents_bn_std
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1634,11 +1634,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
cond_model_input = cond_latents_cache[step].mode()
else:
with offload_models(vae, device=accelerator.device, offload=args.offload):
- pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
- cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
-
- model_input = vae.encode(pixel_values).latent_dist.mode()
- cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
+ pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
+ cond_pixel_values = batch["cond_pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.mode()
+ cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()

model_input = Flux2KleinPipeline._patchify_latents(model_input)
model_input = (model_input - latents_bn_mean) / latents_bn_std
Expand Down
4 changes: 2 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,8 +1665,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
model_input = latents_cache[step].mode()
else:
with offload_models(vae, device=accelerator.device, offload=args.offload):
- pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
- model_input = vae.encode(pixel_values).latent_dist.mode()
+ pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.mode()

model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
# Sample noise that we'll add to the latents
Expand Down
Loading