@@ -906,6 +906,68 @@ def __getitem__(self, index):
         return example
 
 
+# These helpers only matter for prior preservation, where instance and class prompt
+# embedding batches are concatenated and may not share the same mask/sequence length.
+def _materialize_prompt_embedding_mask(
+    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None
+) -> torch.Tensor:
+    """Return a dense mask tensor for a prompt embedding batch."""
+    batch_size, seq_len = prompt_embeds.shape[:2]
+
+    if prompt_embeds_mask is None:
+        return torch.ones((batch_size, seq_len), dtype=torch.long, device=prompt_embeds.device)
+
+    if prompt_embeds_mask.shape != (batch_size, seq_len):
+        raise ValueError(
+            f"`prompt_embeds_mask` shape {prompt_embeds_mask.shape} must match prompt embeddings shape "
+            f"({batch_size}, {seq_len})."
+        )
+
+    return prompt_embeds_mask.to(device=prompt_embeds.device)
+
+
+def _pad_prompt_embedding_pair(
+    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None, target_seq_len: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Pad one prompt embedding batch and its mask to a shared sequence length."""
+    prompt_embeds_mask = _materialize_prompt_embedding_mask(prompt_embeds, prompt_embeds_mask)
+    pad_width = target_seq_len - prompt_embeds.shape[1]
+
+    if pad_width <= 0:
+        return prompt_embeds, prompt_embeds_mask
+
+    prompt_embeds = torch.cat(
+        [prompt_embeds, prompt_embeds.new_zeros(prompt_embeds.shape[0], pad_width, prompt_embeds.shape[2])], dim=1
+    )
+    prompt_embeds_mask = torch.cat(
+        [prompt_embeds_mask, prompt_embeds_mask.new_zeros(prompt_embeds_mask.shape[0], pad_width)], dim=1
+    )
+
+    return prompt_embeds, prompt_embeds_mask
+
+
+def concat_prompt_embedding_batches(
+    *prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor | None],
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    """Concatenate prompt embedding batches while handling missing masks and length mismatches."""
+    if not prompt_embedding_pairs:
+        raise ValueError("At least one prompt embedding pair must be provided.")
+
+    target_seq_len = max(prompt_embeds.shape[1] for prompt_embeds, _ in prompt_embedding_pairs)
+    padded_pairs = [
+        _pad_prompt_embedding_pair(prompt_embeds, prompt_embeds_mask, target_seq_len)
+        for prompt_embeds, prompt_embeds_mask in prompt_embedding_pairs
+    ]
+
+    merged_prompt_embeds = torch.cat([prompt_embeds for prompt_embeds, _ in padded_pairs], dim=0)
+    merged_mask = torch.cat([prompt_embeds_mask for _, prompt_embeds_mask in padded_pairs], dim=0)
+
+    if merged_mask.all():
+        return merged_prompt_embeds, None
+
+    return merged_prompt_embeds, merged_mask
+
+
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -1320,8 +1382,10 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
         prompt_embeds = instance_prompt_embeds
         prompt_embeds_mask = instance_prompt_embeds_mask
         if args.with_prior_preservation:
-            prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
-            prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
+            prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches(
+                (instance_prompt_embeds, instance_prompt_embeds_mask),
+                (class_prompt_embeds, class_prompt_embeds_mask),
+            )
 
     # if cache_latents is set to True, we encode images to latents and store them.
     # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
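
Below is a minimal, self-contained sketch of what the new `concat_prompt_embedding_batches` helper does in the prior-preservation path. The tensor shapes (batch of 2, instance prompts of 6 tokens, class prompts of 4 tokens, hidden size 8) are made-up values for illustration only, and the helpers are assumed to be in scope as defined in the diff above.

```python
import torch

# Assumes concat_prompt_embedding_batches and its helpers from the diff above are defined
# in the same module. All shapes below are hypothetical, chosen only for illustration.
instance_prompt_embeds = torch.randn(2, 6, 8)                     # 2 prompts, 6 tokens, hidden size 8
instance_prompt_embeds_mask = torch.ones(2, 6, dtype=torch.long)  # all tokens valid
class_prompt_embeds = torch.randn(2, 4, 8)                        # class prompts tokenized to only 4 tokens
class_prompt_embeds_mask = None                                   # e.g. the encoder returned no mask

prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches(
    (instance_prompt_embeds, instance_prompt_embeds_mask),
    (class_prompt_embeds, class_prompt_embeds_mask),
)

# The class batch is zero-padded from 4 to 6 tokens, its missing mask is materialized,
# and both batches are stacked along the batch dimension.
print(prompt_embeds.shape)       # torch.Size([4, 6, 8])
print(prompt_embeds_mask.shape)  # torch.Size([4, 6])
print(prompt_embeds_mask[2])     # tensor([1, 1, 1, 1, 0, 0])
```

Because the padded class rows contain masked-out positions, the merged mask is not all ones, so the helper returns it instead of collapsing it to `None`.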