Skip to content

Commit bf92e45

Browse files
Add input validation and fix XLA support for ZImageInpaintPipeline
- Add the missing `is_torch_xla_available` import for TPU support.
- Add `xm.mark_step()` in the denoising loop for proper XLA execution.
- Add a `check_inputs()` method for comprehensive input validation.
- Call `check_inputs()` at the start of `__call__`.

Addresses PR review feedback from @asomoza.
1 parent 79320a6 commit bf92e45

1 file changed

Lines changed: 70 additions & 5 deletions

File tree

src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ...models.transformers import ZImageTransformer2DModel
2525
from ...pipelines.pipeline_utils import DiffusionPipeline
2626
from ...schedulers import FlowMatchEulerDiscreteScheduler
27-
from ...utils import logging, replace_example_docstring
27+
from ...utils import is_torch_xla_available, logging, replace_example_docstring
2828
from ...utils.torch_utils import randn_tensor
2929
from .pipeline_output import ZImagePipelineOutput
3030

@@ -35,7 +35,7 @@
3535
XLA_AVAILABLE = True
3636
else:
3737
XLA_AVAILABLE = False
38-
38+
3939
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4040

4141
EXAMPLE_DOC_STRING = """
@@ -476,6 +476,57 @@ def num_timesteps(self):
476476
def interrupt(self):
477477
return self._interrupt
478478

479+
def check_inputs(
    self,
    prompt,
    image,
    mask_image,
    strength,
    height,
    width,
    output_type,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the user-facing arguments of the inpainting call.

    Every check raises `ValueError` on failure; the method returns `None`
    when all checks pass.

    Args:
        prompt: Text prompt; must be a `str` or a `list` when given.
        image: Input image; must not be `None`.
        mask_image: Inpainting mask; must not be `None`.
        strength: Must lie in the closed interval [0.0, 1.0].
        height: Accepted for interface completeness; not validated here.
        width: Accepted for interface completeness; not validated here.
        output_type: One of "latent", "pil", "np" or "pt".
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Mutually exclusive with `prompt`; exactly one of the
            two must be provided.
        negative_prompt_embeds: Mutually exclusive with `negative_prompt`.
        callback_on_step_end_tensor_inputs: Optional iterable of names; each
            must be a member of `self._callback_tensor_inputs`.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    # A chained comparison is False for NaN, so NaN is rejected here as well
    # (the `< 0 or > 1` form lets NaN through).
    if not 0 <= strength <= 1:
        raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

    if callback_on_step_end_tensor_inputs is not None:
        # Collect the offending names once and reuse them in the message.
        unknown_keys = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown_keys:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_keys}"
            )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if image is None:
        raise ValueError("`image` input cannot be undefined for inpainting.")

    if mask_image is None:
        raise ValueError("`mask_image` input cannot be undefined for inpainting.")

    # A tuple (not a set) keeps the membership test safe for unhashable values.
    if output_type not in ("latent", "pil", "np", "pt"):
        raise ValueError(f"`output_type` must be one of 'latent', 'pil', 'np', or 'pt', but got {output_type}")
529+
479530
@torch.no_grad()
480531
@replace_example_docstring(EXAMPLE_DOC_STRING)
481532
def __call__(
@@ -598,9 +649,20 @@ def __call__(
598649
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
599650
generated images.
600651
"""
601-
# 1. Check inputs and validate strength
602-
if strength < 0 or strength > 1:
603-
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
652+
# 1. Check inputs
653+
self.check_inputs(
654+
prompt=prompt,
655+
image=image,
656+
mask_image=mask_image,
657+
strength=strength,
658+
height=height,
659+
width=width,
660+
output_type=output_type,
661+
negative_prompt=negative_prompt,
662+
prompt_embeds=prompt_embeds,
663+
negative_prompt_embeds=negative_prompt_embeds,
664+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
665+
)
604666

605667
# 2. Preprocess image and mask
606668
init_image = self.image_processor.preprocess(image)
@@ -848,6 +910,9 @@ def __call__(
848910
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
849911
progress_bar.update()
850912

913+
if XLA_AVAILABLE:
914+
xm.mark_step()
915+
851916
if output_type == "latent":
852917
image = latents
853918

0 commit comments

Comments
 (0)