@@ -1740,9 +1740,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17401740 prompt_embeds = prompt_embeds_cache [step ]
17411741 text_ids = text_ids_cache [step ]
17421742 else :
1743- num_repeat_elements = len (prompts )
1744- prompt_embeds = prompt_embeds .repeat (num_repeat_elements , 1 , 1 )
1745- text_ids = text_ids .repeat (num_repeat_elements , 1 , 1 )
1743+ # With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
1744+ # while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
1745+ # dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
1746+ num_repeat_elements = len (prompts ) // 2 if args .with_prior_preservation else len (prompts )
1747+ prompt_embeds = prompt_embeds .repeat_interleave (num_repeat_elements , dim = 0 )
1748+ text_ids = text_ids .repeat_interleave (num_repeat_elements , dim = 0 )
17461749
17471750 # Convert images to latent space
17481751 if args .cache_latents :
@@ -1809,10 +1812,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18091812 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
18101813 model_pred , model_pred_prior = torch .chunk (model_pred , 2 , dim = 0 )
18111814 target , target_prior = torch .chunk (target , 2 , dim = 0 )
1815+ weighting , weighting_prior = torch .chunk (weighting , 2 , dim = 0 )
18121816
18131817 # Compute prior loss
18141818 prior_loss = torch .mean (
1815- (weighting .float () * (model_pred_prior .float () - target_prior .float ()) ** 2 ).reshape (
1819+ (weighting_prior .float () * (model_pred_prior .float () - target_prior .float ()) ** 2 ).reshape (
18161820 target_prior .shape [0 ], - 1
18171821 ),
18181822 1 ,
0 commit comments