Skip to content

Commit 771512a

Browse files
committed
update
1 parent b0f50c6 commit 771512a

2 files changed

Lines changed: 73 additions & 2 deletions

File tree

src/diffusers/modular_pipelines/flux2/inputs.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616

1717
import torch
1818

19+
from ...configuration_utils import FrozenDict
20+
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
1921
from ...utils import logging
2022
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
21-
from ..modular_pipeline_utils import InputParam, OutputParam
23+
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
2224
from .modular_pipeline import Flux2ModularPipeline
2325

2426

@@ -85,3 +87,72 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
8587

8688
self.set_block_state(state, block_state)
8789
return components, state
90+
91+
92+
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
93+
model_name = "flux2"
94+
95+
@property
96+
def description(self) -> str:
97+
return "Image preprocess step for Flux2. Validates and preprocesses reference images."
98+
99+
@property
100+
def expected_components(self) -> List[ComponentSpec]:
101+
return [
102+
ComponentSpec(
103+
"image_processor",
104+
Flux2ImageProcessor,
105+
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
106+
default_creation_method="from_config",
107+
),
108+
]
109+
110+
@property
111+
def inputs(self) -> List[InputParam]:
112+
return [
113+
InputParam("image"),
114+
InputParam("height"),
115+
InputParam("width"),
116+
]
117+
118+
@property
119+
def intermediate_outputs(self) -> List[OutputParam]:
120+
return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]
121+
122+
@torch.no_grad()
123+
def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
124+
block_state = self.get_block_state(state)
125+
images = block_state.image
126+
127+
if images is None:
128+
block_state.condition_images = None
129+
else:
130+
if not isinstance(images, list):
131+
images = [images]
132+
133+
condition_images = []
134+
for img in images:
135+
components.image_processor.check_image_input(img)
136+
137+
image_width, image_height = img.size
138+
if image_width * image_height > 1024 * 1024:
139+
img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
140+
image_width, image_height = img.size
141+
142+
multiple_of = components.vae_scale_factor * 2
143+
image_width = (image_width // multiple_of) * multiple_of
144+
image_height = (image_height // multiple_of) * multiple_of
145+
condition_img = components.image_processor.preprocess(
146+
img, height=image_height, width=image_width, resize_mode="crop"
147+
)
148+
condition_images.append(condition_img)
149+
150+
if block_state.height is None:
151+
block_state.height = image_height
152+
if block_state.width is None:
153+
block_state.width = image_width
154+
155+
block_state.condition_images = condition_images
156+
157+
self.set_block_state(state, block_state)
158+
return components, state

src/diffusers/modular_pipelines/flux2/modular_blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
from .decoders import Flux2DecodeStep
2525
from .denoise import Flux2DenoiseStep
2626
from .encoders import (
27-
Flux2ProcessImagesInputStep,
2827
Flux2RemoteTextEncoderStep,
2928
Flux2TextEncoderStep,
3029
Flux2VaeEncoderStep,
3130
)
3231
from .inputs import (
32+
Flux2ProcessImagesInputStep,
3333
Flux2TextInputStep,
3434
)
3535

0 commit comments

Comments
 (0)