Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
prompt_embeds = prompt_embeds_cache[step]
prompt_embeds_mask = prompt_embeds_mask_cache[step]
else:
num_repeat_elements = len(prompts)
# With prior preservation, prompt_embeds already contains [instance, class] embeddings
# from the cat above, but collate_fn also doubles the prompts list. Use half the
# prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
Expand Down