Skip to content

Commit 0fcfb2c

Browse files
authored
Merge branch 'main' into fix-qwenimage-rope-sync
2 parents 3ff78fd + b80d3f6 commit 0fcfb2c

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1529,7 +1529,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15291529
prompt_embeds = prompt_embeds_cache[step]
15301530
prompt_embeds_mask = prompt_embeds_mask_cache[step]
15311531
else:
1532-
num_repeat_elements = len(prompts)
1532+
# With prior preservation, prompt_embeds already contains [instance, class] embeddings
1533+
# from the cat above, but collate_fn also doubles the prompts list. Use half the
1534+
# prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
1535+
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
15331536
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
15341537
if prompt_embeds_mask is not None:
15351538
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)

0 commit comments

Comments
 (0)