
Commit acc07f5

Authored by chenyangzhu1, github-actions[bot], and sayakpaul
Handle prompt embedding concat in Qwen dreambooth example (#13387)
* Handle prompt embedding concat in Qwen dreambooth example

* remove wandb config

* Apply style fixes

* add a comment on how this is only relevant during prior preservation.

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
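The motivation is easiest to see with a small sketch, not part of the commit itself. With prior preservation enabled, the instance and class prompts can tokenize to different sequence lengths, and the old code concatenated their embeddings along the batch dimension directly. The shapes below are illustrative only, not taken from the training script:

import torch

# Illustrative shapes: (batch, seq_len, hidden_dim). Real values depend on
# the Qwen text encoder and the prompts being encoded.
instance_prompt_embeds = torch.randn(2, 512, 64)
class_prompt_embeds = torch.randn(2, 448, 64)  # shorter class prompt

# torch.cat along dim 0 requires every other dimension to match, so a
# sequence-length mismatch raises a RuntimeError: the failure this commit fixes.
try:
    torch.cat([instance_prompt_embeds, class_prompt_embeds], dim=0)
except RuntimeError as err:
    print(f"naive concat fails: {err}")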
1 parent 431066e commit acc07f5

1 file changed

examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 66 additions & 2 deletions
@@ -906,6 +906,68 @@ def __getitem__(self, index):
         return example
 
 
+# These helpers only matter for prior preservation, where instance and class prompt
+# embedding batches are concatenated and may not share the same mask/sequence length.
+def _materialize_prompt_embedding_mask(
+    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None
+) -> torch.Tensor:
+    """Return a dense mask tensor for a prompt embedding batch."""
+    batch_size, seq_len = prompt_embeds.shape[:2]
+
+    if prompt_embeds_mask is None:
+        return torch.ones((batch_size, seq_len), dtype=torch.long, device=prompt_embeds.device)
+
+    if prompt_embeds_mask.shape != (batch_size, seq_len):
+        raise ValueError(
+            f"`prompt_embeds_mask` shape {prompt_embeds_mask.shape} must match prompt embeddings shape "
+            f"({batch_size}, {seq_len})."
+        )
+
+    return prompt_embeds_mask.to(device=prompt_embeds.device)
+
+
+def _pad_prompt_embedding_pair(
+    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None, target_seq_len: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Pad one prompt embedding batch and its mask to a shared sequence length."""
+    prompt_embeds_mask = _materialize_prompt_embedding_mask(prompt_embeds, prompt_embeds_mask)
+    pad_width = target_seq_len - prompt_embeds.shape[1]
+
+    if pad_width <= 0:
+        return prompt_embeds, prompt_embeds_mask
+
+    prompt_embeds = torch.cat(
+        [prompt_embeds, prompt_embeds.new_zeros(prompt_embeds.shape[0], pad_width, prompt_embeds.shape[2])], dim=1
+    )
+    prompt_embeds_mask = torch.cat(
+        [prompt_embeds_mask, prompt_embeds_mask.new_zeros(prompt_embeds_mask.shape[0], pad_width)], dim=1
+    )
+
+    return prompt_embeds, prompt_embeds_mask
+
+
+def concat_prompt_embedding_batches(
+    *prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor | None],
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    """Concatenate prompt embedding batches while handling missing masks and length mismatches."""
+    if not prompt_embedding_pairs:
+        raise ValueError("At least one prompt embedding pair must be provided.")
+
+    target_seq_len = max(prompt_embeds.shape[1] for prompt_embeds, _ in prompt_embedding_pairs)
+    padded_pairs = [
+        _pad_prompt_embedding_pair(prompt_embeds, prompt_embeds_mask, target_seq_len)
+        for prompt_embeds, prompt_embeds_mask in prompt_embedding_pairs
+    ]
+
+    merged_prompt_embeds = torch.cat([prompt_embeds for prompt_embeds, _ in padded_pairs], dim=0)
+    merged_mask = torch.cat([prompt_embeds_mask for _, prompt_embeds_mask in padded_pairs], dim=0)
+
+    if merged_mask.all():
+        return merged_prompt_embeds, None
+
+    return merged_prompt_embeds, merged_mask
+
+
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -1320,8 +1382,10 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
         prompt_embeds = instance_prompt_embeds
         prompt_embeds_mask = instance_prompt_embeds_mask
         if args.with_prior_preservation:
-            prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
-            prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
+            prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches(
+                (instance_prompt_embeds, instance_prompt_embeds_mask),
+                (class_prompt_embeds, class_prompt_embeds_mask),
+            )
 
     # if cache_latents is set to True, we encode images to latents and store them.
     # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
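A quick sketch of how the new helper behaves, assuming the functions from the diff above are in scope; the tensor shapes are again illustrative, not values from the script:

import torch

instance_embeds = torch.randn(2, 512, 64)
class_embeds = torch.randn(2, 448, 64)

# Masks of None are materialized as all-ones; the shorter class batch is
# zero-padded from 448 to 512 tokens and its mask zeroes out the padding.
embeds, mask = concat_prompt_embedding_batches(
    (instance_embeds, None),
    (class_embeds, None),
)
print(embeds.shape)  # torch.Size([4, 512, 64])
print(mask.shape)    # torch.Size([4, 512]); zeros over the padded positions

# When the batches already share a length and every position is attended,
# the merged mask is all ones, so the helper returns None instead.
embeds, mask = concat_prompt_embedding_batches(
    (instance_embeds, None),
    (torch.randn(2, 512, 64), None),
)
print(mask)  # None

Returning None for an all-ones mask keeps the downstream call path identical to the no-mask case, so runs that never hit a length mismatch see no behavioral change.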
