Commit b80d3f6

fix(qwen-image dreambooth): correct prompt embed repeats when using --with_prior_preservation (#13396)
fix(qwen): correct prompt embed repeats with prior preservation Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent acc07f5 commit b80d3f6

1 file changed

Lines changed: 4 additions & 1 deletion


examples/dreambooth/train_dreambooth_lora_qwen_image.py

```diff
@@ -1529,7 +1529,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             prompt_embeds = prompt_embeds_cache[step]
             prompt_embeds_mask = prompt_embeds_mask_cache[step]
         else:
-            num_repeat_elements = len(prompts)
+            # With prior preservation, prompt_embeds already contains [instance, class] embeddings
+            # from the cat above, but collate_fn also doubles the prompts list. Use half the
+            # prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
+            num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
             prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
             if prompt_embeds_mask is not None:
                 prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
```
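The shape mismatch this commit fixes can be sketched with a minimal shape-only example. The sketch below uses NumPy's `tile` to mimic `torch.Tensor.repeat` along the batch dimension; the batch size, sequence length, and embedding dimension are illustrative assumptions, not values taken from the training script.

```python
import numpy as np

# Hypothetical shapes mirroring the training loop: with --with_prior_preservation
# the collate_fn doubles the batch, so a per-device batch of B yields 2*B prompts
# and 2*B latents, while prompt_embeds already holds the [instance, class] pair
# (2 rows) from the earlier concatenation.
B = 4                                          # assumed instance batch size
prompts = ["instance"] * B + ["class"] * B     # collate_fn output: 2*B prompts
latents = np.zeros((2 * B, 16, 64))            # 2*B latents (dims illustrative)
prompt_embeds = np.zeros((2, 77, 3584))        # [instance, class] embedding pair

# Buggy repeat count: len(prompts) == 2*B over-repeats by a factor of 2,
# producing 4*B embeddings for only 2*B latents.
buggy = np.tile(prompt_embeds, (len(prompts), 1, 1))
assert buggy.shape[0] == 4 * B                 # 16 embeddings vs. 8 latents

# Fixed repeat count: half the prompts list, matching the latent batch.
fixed = np.tile(prompt_embeds, (len(prompts) // 2, 1, 1))
assert fixed.shape[0] == latents.shape[0]      # 8 == 8
```

With the fix, the repeated embeddings line up one-to-one with the latents in the doubled batch, which is the invariant the loss computation downstream relies on.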
