|
16 | 16 |
|
17 | 17 | import torch |
18 | 18 |
|
| 19 | +from ...configuration_utils import FrozenDict |
| 20 | +from ...pipelines.flux2.image_processor import Flux2ImageProcessor |
19 | 21 | from ...utils import logging |
20 | 22 | from ..modular_pipeline import ModularPipelineBlocks, PipelineState |
21 | | -from ..modular_pipeline_utils import InputParam, OutputParam |
| 23 | +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam |
22 | 24 | from .modular_pipeline import Flux2ModularPipeline |
23 | 25 |
|
24 | 26 |
|
@@ -85,3 +87,72 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi |
85 | 87 |
|
86 | 88 | self.set_block_state(state, block_state) |
87 | 89 | return components, state |
| 90 | + |
| 91 | + |
| 92 | +class Flux2ProcessImagesInputStep(ModularPipelineBlocks): |
| 93 | + model_name = "flux2" |
| 94 | + |
| 95 | + @property |
| 96 | + def description(self) -> str: |
| 97 | + return "Image preprocess step for Flux2. Validates and preprocesses reference images." |
| 98 | + |
| 99 | + @property |
| 100 | + def expected_components(self) -> List[ComponentSpec]: |
| 101 | + return [ |
| 102 | + ComponentSpec( |
| 103 | + "image_processor", |
| 104 | + Flux2ImageProcessor, |
| 105 | + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), |
| 106 | + default_creation_method="from_config", |
| 107 | + ), |
| 108 | + ] |
| 109 | + |
| 110 | + @property |
| 111 | + def inputs(self) -> List[InputParam]: |
| 112 | + return [ |
| 113 | + InputParam("image"), |
| 114 | + InputParam("height"), |
| 115 | + InputParam("width"), |
| 116 | + ] |
| 117 | + |
| 118 | + @property |
| 119 | + def intermediate_outputs(self) -> List[OutputParam]: |
| 120 | + return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])] |
| 121 | + |
| 122 | + @torch.no_grad() |
| 123 | + def __call__(self, components: Flux2ModularPipeline, state: PipelineState): |
| 124 | + block_state = self.get_block_state(state) |
| 125 | + images = block_state.image |
| 126 | + |
| 127 | + if images is None: |
| 128 | + block_state.condition_images = None |
| 129 | + else: |
| 130 | + if not isinstance(images, list): |
| 131 | + images = [images] |
| 132 | + |
| 133 | + condition_images = [] |
| 134 | + for img in images: |
| 135 | + components.image_processor.check_image_input(img) |
| 136 | + |
| 137 | + image_width, image_height = img.size |
| 138 | + if image_width * image_height > 1024 * 1024: |
| 139 | + img = components.image_processor._resize_to_target_area(img, 1024 * 1024) |
| 140 | + image_width, image_height = img.size |
| 141 | + |
| 142 | + multiple_of = components.vae_scale_factor * 2 |
| 143 | + image_width = (image_width // multiple_of) * multiple_of |
| 144 | + image_height = (image_height // multiple_of) * multiple_of |
| 145 | + condition_img = components.image_processor.preprocess( |
| 146 | + img, height=image_height, width=image_width, resize_mode="crop" |
| 147 | + ) |
| 148 | + condition_images.append(condition_img) |
| 149 | + |
| 150 | + if block_state.height is None: |
| 151 | + block_state.height = image_height |
| 152 | + if block_state.width is None: |
| 153 | + block_state.width = image_width |
| 154 | + |
| 155 | + block_state.condition_images = condition_images |
| 156 | + |
| 157 | + self.set_block_state(state, block_state) |
| 158 | + return components, state |
0 commit comments