Skip to content

Commit b6cb7b1

Browse files
committed
Fix Flux2 DreamBooth prior preservation prompt repeats
1 parent fbe8a75 commit b6cb7b1

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,9 +1740,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17401740
prompt_embeds = prompt_embeds_cache[step]
17411741
text_ids = text_ids_cache[step]
17421742
else:
1743-
num_repeat_elements = len(prompts)
1744-
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
1745-
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
1743+
# With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
1744+
# while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
1745+
# dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
1746+
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
1747+
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
1748+
text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
17461749

17471750
# Convert images to latent space
17481751
if args.cache_latents:

examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,9 +1680,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16801680
prompt_embeds = prompt_embeds_cache[step]
16811681
text_ids = text_ids_cache[step]
16821682
else:
1683-
num_repeat_elements = len(prompts)
1684-
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
1685-
text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
1683+
# With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
1684+
# while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
1685+
# dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
1686+
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
1687+
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
1688+
text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
16861689

16871690
# Convert images to latent space
16881691
if args.cache_latents:

0 commit comments

Comments
 (0)