Skip to content

Commit 3a00206

Browse files
committed
Fix Flux2 DreamBooth prior preservation prompt repeats
1 parent fbe8a75 commit 3a00206

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,9 +1740,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17401740
prompt_embeds = prompt_embeds_cache[step]
17411741
text_ids = text_ids_cache[step]
17421742
else:
1743-
num_repeat_elements = len(prompts)
1744-
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
1745-
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
1743+
# With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
1744+
# while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
1745+
# dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
1746+
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
1747+
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
1748+
text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
17461749

17471750
# Convert images to latent space
17481751
if args.cache_latents:
@@ -1809,10 +1812,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18091812
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
18101813
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
18111814
target, target_prior = torch.chunk(target, 2, dim=0)
1815+
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
18121816

18131817
# Compute prior loss
18141818
prior_loss = torch.mean(
1815-
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
1819+
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
18161820
target_prior.shape[0], -1
18171821
),
18181822
1,

examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,9 +1680,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16801680
prompt_embeds = prompt_embeds_cache[step]
16811681
text_ids = text_ids_cache[step]
16821682
else:
1683-
num_repeat_elements = len(prompts)
1684-
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
1685-
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
1683+
# With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
1684+
# while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
1685+
# dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
1686+
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
1687+
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
1688+
text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
16861689

16871690
# Convert images to latent space
16881691
if args.cache_latents:
@@ -1752,10 +1755,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17521755
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
17531756
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
17541757
target, target_prior = torch.chunk(target, 2, dim=0)
1758+
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
17551759

17561760
# Compute prior loss
17571761
prior_loss = torch.mean(
1758-
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
1762+
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
17591763
target_prior.shape[0], -1
17601764
),
17611765
1,

0 commit comments

Comments (0)