@@ -353,8 +353,20 @@ def prepare_mask_latents(
353353
354354 # Expand for batch size
355355 if mask .shape [0 ] < batch_size :
356+ if not batch_size % mask .shape [0 ] == 0 :
357+ raise ValueError (
358+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
359+ f" a total batch size of { batch_size } , but { mask .shape [0 ]} masks were passed. Make sure the number"
360+ " of masks that you pass is divisible by the total requested batch size."
361+ )
356362 mask = mask .repeat (batch_size // mask .shape [0 ], 1 , 1 , 1 )
357363 if masked_image_latents .shape [0 ] < batch_size :
364+ if not batch_size % masked_image_latents .shape [0 ] == 0 :
365+ raise ValueError (
366+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
367+ f" to a total batch size of { batch_size } , but { masked_image_latents .shape [0 ]} images were passed."
368+ " Make sure the number of images that you pass is divisible by the total requested batch size."
369+ )
358370 masked_image_latents = masked_image_latents .repeat (batch_size // masked_image_latents .shape [0 ], 1 , 1 , 1 )
359371
360372 return mask , masked_image_latents
@@ -822,6 +834,8 @@ def __call__(
822834 latents = callback_outputs .pop ("latents" , latents )
823835 prompt_embeds = callback_outputs .pop ("prompt_embeds" , prompt_embeds )
824836 negative_prompt_embeds = callback_outputs .pop ("negative_prompt_embeds" , negative_prompt_embeds )
837+ mask = callback_outputs .pop ("mask" , mask )
838+ masked_image_latents = callback_outputs .pop ("masked_image_latents" , masked_image_latents )
825839
826840 # call the callback, if provided
827841 if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
0 commit comments