Skip to content

Commit 387a471

Browse files
examples/dreambooth: fix missing weighting chunk when using prior preservation in Flux and SD3 LoRA training (#13743)
* examples/dreambooth: chunk weighting tensor alongside model_pred and target when using prior preservation (flux LoRA) * examples/dreambooth: chunk weighting tensor alongside model_pred and target when using prior preservation (SD3 LoRA) --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 2f4a717 commit 387a471

2 files changed

Lines changed: 4 additions & 2 deletions

File tree

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1823,10 +1823,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18231823
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
18241824
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
18251825
target, target_prior = torch.chunk(target, 2, dim=0)
1826+
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
18261827

18271828
# Compute prior loss
18281829
prior_loss = torch.mean(
1829-
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
1830+
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
18301831
target_prior.shape[0], -1
18311832
),
18321833
1,

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1824,10 +1824,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18241824
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
18251825
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
18261826
target, target_prior = torch.chunk(target, 2, dim=0)
1827+
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
18271828

18281829
# Compute prior loss
18291830
prior_loss = torch.mean(
1830-
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
1831+
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
18311832
target_prior.shape[0], -1
18321833
),
18331834
1,

0 commit comments

Comments
 (0)