|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import numpy as np |
16 | 15 | import torch |
17 | | -from PIL import Image |
18 | 16 |
|
19 | 17 | from ...configuration_utils import FrozenDict |
| 18 | +from ...image_processor import VaeImageProcessor |
20 | 19 | from ...models import AutoencoderKLFlux2 |
21 | 20 | from ...utils import logging |
22 | 21 | from ..modular_pipeline import ModularPipelineBlocks, PipelineState |
@@ -44,6 +43,12 @@ def expected_components(self) -> list[ComponentSpec]: |
44 | 43 | config=FrozenDict({"patch_size": 2}), |
45 | 44 | default_creation_method="from_config", |
46 | 45 | ), |
| 46 | + ComponentSpec( |
| 47 | + "image_processor", |
| 48 | + VaeImageProcessor, |
| 49 | + config=FrozenDict({"vae_scale_factor": 16}), |
| 50 | + default_creation_method="from_config", |
| 51 | + ), |
47 | 52 | ] |
48 | 53 |
|
49 | 54 | @property |
@@ -81,18 +86,7 @@ def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) |
81 | 86 | latents = components.pachifier.unpack_latents(latents) |
82 | 87 |
|
83 | 88 | images = vae.decode(latents.to(vae.dtype), return_dict=False)[0] |
84 | | - images = (images.clamp(-1, 1) + 1) / 2 |
85 | | - |
86 | | - output_type = block_state.output_type |
87 | | - if output_type == "pt": |
88 | | - block_state.images = images |
89 | | - elif output_type == "np": |
90 | | - block_state.images = images.cpu().permute(0, 2, 3, 1).float().numpy() |
91 | | - elif output_type == "pil": |
92 | | - images_np = images.cpu().permute(0, 2, 3, 1).float().numpy() |
93 | | - block_state.images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images_np] |
94 | | - else: |
95 | | - raise ValueError(f"Unsupported `output_type`: {output_type!r}. Expected one of 'pil', 'np', 'pt'.") |
| 89 | + block_state.images = components.image_processor.postprocess(images, output_type=block_state.output_type) |
96 | 90 |
|
97 | 91 | self.set_block_state(state, block_state) |
98 | 92 | return components, state |
0 commit comments