|
24 | 24 | from ...models.transformers import ZImageTransformer2DModel |
25 | 25 | from ...pipelines.pipeline_utils import DiffusionPipeline |
26 | 26 | from ...schedulers import FlowMatchEulerDiscreteScheduler |
27 | | -from ...utils import logging, replace_example_docstring |
| 27 | +from ...utils import is_torch_xla_available, logging, replace_example_docstring |
28 | 28 | from ...utils.torch_utils import randn_tensor |
29 | 29 | from .pipeline_output import ZImagePipelineOutput |
30 | 30 |
|
|
35 | 35 | XLA_AVAILABLE = True |
36 | 36 | else: |
37 | 37 | XLA_AVAILABLE = False |
38 | | - |
| 38 | + |
39 | 39 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
40 | 40 |
|
41 | 41 | EXAMPLE_DOC_STRING = """ |
@@ -476,6 +476,57 @@ def num_timesteps(self): |
476 | 476 | def interrupt(self): |
477 | 477 | return self._interrupt |
478 | 478 |
|
def check_inputs(
    self,
    prompt,
    image,
    mask_image,
    strength,
    height,
    width,
    output_type,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate user-supplied inputs before running the inpainting pipeline.

    Raises a ``ValueError`` describing the first problem found, in this order:
    strength out of range, unknown callback tensor names, conflicting or
    missing/ill-typed prompt inputs, conflicting negative-prompt inputs,
    missing ``image`` / ``mask_image``, and finally an unsupported
    ``output_type``.

    NOTE(review): ``height`` and ``width`` are accepted but not validated here
    (e.g. no divisibility check against the VAE scale factor) — presumably
    handled elsewhere; confirm against the caller.
    """
    # Inpainting strength must be a fraction of the denoising schedule.
    if strength > 1 or strength < 0:
        raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

    # Every requested callback tensor must be one the pipeline can expose.
    if callback_on_step_end_tensor_inputs is not None:
        if not all(k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

    # Exactly one of `prompt` / `prompt_embeds` must be given, and a raw
    # prompt must be a string or a list of strings.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # At most one of `negative_prompt` / `negative_prompt_embeds` may be given.
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Inpainting requires both the source image and its mask.
    if image is None:
        raise ValueError("`image` input cannot be undefined for inpainting.")
    if mask_image is None:
        raise ValueError("`mask_image` input cannot be undefined for inpainting.")

    # Only the formats the post-processing step understands are allowed.
    if output_type not in ("latent", "pil", "np", "pt"):
        raise ValueError(f"`output_type` must be one of 'latent', 'pil', 'np', or 'pt', but got {output_type}")
479 | 530 | @torch.no_grad() |
480 | 531 | @replace_example_docstring(EXAMPLE_DOC_STRING) |
481 | 532 | def __call__( |
@@ -598,9 +649,20 @@ def __call__( |
598 | 649 | `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the |
599 | 650 | generated images. |
600 | 651 | """ |
601 | | - # 1. Check inputs and validate strength |
602 | | - if strength < 0 or strength > 1: |
603 | | - raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") |
| 652 | + # 1. Check inputs |
| 653 | + self.check_inputs( |
| 654 | + prompt=prompt, |
| 655 | + image=image, |
| 656 | + mask_image=mask_image, |
| 657 | + strength=strength, |
| 658 | + height=height, |
| 659 | + width=width, |
| 660 | + output_type=output_type, |
| 661 | + negative_prompt=negative_prompt, |
| 662 | + prompt_embeds=prompt_embeds, |
| 663 | + negative_prompt_embeds=negative_prompt_embeds, |
| 664 | + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, |
| 665 | + ) |
604 | 666 |
|
605 | 667 | # 2. Preprocess image and mask |
606 | 668 | init_image = self.image_processor.preprocess(image) |
@@ -848,6 +910,9 @@ def __call__( |
848 | 910 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
849 | 911 | progress_bar.update() |
850 | 912 |
|
| 913 | + if XLA_AVAILABLE: |
| 914 | + xm.mark_step() |
| 915 | + |
851 | 916 | if output_type == "latent": |
852 | 917 | image = latents |
853 | 918 |
|
|
0 commit comments