Skip to content

Commit fb48046

Browse files
Address PR review feedback for ZImageInpaintPipeline
Add batch size validation and callback handling fixes per review, using diffusers conventions rather than suggested code verbatim.
1 parent 94e653d commit fb48046

1 file changed

Lines changed: 14 additions & 0 deletions

File tree

src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,20 @@ def prepare_mask_latents(
353353

354354
# Expand for batch size
355355
if mask.shape[0] < batch_size:
356+
if not batch_size % mask.shape[0] == 0:
357+
raise ValueError(
358+
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
359+
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
360+
" of masks that you pass is divisible by the total requested batch size."
361+
)
356362
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
357363
if masked_image_latents.shape[0] < batch_size:
364+
if not batch_size % masked_image_latents.shape[0] == 0:
365+
raise ValueError(
366+
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
367+
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
368+
" Make sure the number of images that you pass is divisible by the total requested batch size."
369+
)
358370
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
359371

360372
return mask, masked_image_latents
@@ -822,6 +834,8 @@ def __call__(
822834
latents = callback_outputs.pop("latents", latents)
823835
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
824836
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
837+
mask = callback_outputs.pop("mask", mask)
838+
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
825839

826840
# call the callback, if provided
827841
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

0 commit comments

Comments
 (0)