
Commit 49b78c7

fix(qwen): correct prompt embed repeats with prior preservation

1 parent 8070f6e

1 file changed: 4 additions, 1 deletion

examples/dreambooth/train_dreambooth_lora_qwen_image.py
@@ -1465,7 +1465,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     prompt_embeds = prompt_embeds_cache[step]
                     prompt_embeds_mask = prompt_embeds_mask_cache[step]
                 else:
-                    num_repeat_elements = len(prompts)
+                    # With prior preservation, prompt_embeds already contains [instance, class] embeddings
+                    # from the cat above, but collate_fn also doubles the prompts list. Use half the
+                    # prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
+                    num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
                     prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
                     if prompt_embeds_mask is not None:
                         prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
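A minimal sketch of why the repeat count must be halved under prior preservation. The tensor shapes and prompt lists here are stand-ins for the real text-encoder output and `collate_fn` behavior, not the actual training script:

```python
import torch

batch_size = 2       # instance images per batch (assumed for illustration)
seq_len, dim = 8, 16 # stand-in embedding shape

# One embedding each for the instance prompt and the class prompt.
instance_embeds = torch.randn(1, seq_len, dim)
class_embeds = torch.randn(1, seq_len, dim)

with_prior_preservation = True

if with_prior_preservation:
    # The script concatenates instance and class embeddings -> 2 rows.
    prompt_embeds = torch.cat([instance_embeds, class_embeds], dim=0)
    # collate_fn also doubles the prompts list (instance + class prompts).
    prompts = ["instance prompt"] * batch_size + ["class prompt"] * batch_size
else:
    prompt_embeds = instance_embeds
    prompts = ["instance prompt"] * batch_size

# One latent per image; prior preservation doubles the image count too.
num_latents = len(prompts)  # 4

# Buggy count: repeating the 2-row embedding len(prompts)=4 times yields
# 8 rows of embeddings for only 4 latents.
buggy = prompt_embeds.repeat(len(prompts), 1, 1)
assert buggy.shape[0] == 2 * num_latents

# Fixed count: halve the repeat so embedding rows match the latents (2*2=4).
num_repeat_elements = len(prompts) // 2 if with_prior_preservation else len(prompts)
fixed = prompt_embeds.repeat(num_repeat_elements, 1, 1)
assert fixed.shape[0] == num_latents
```

With the fix, `prompt_embeds` ends up with exactly one row per latent in the batch instead of twice as many.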

0 commit comments