Skip to content

Commit 921b959

Browse files
committed
update
1 parent 9391a54 commit 921b959

6 files changed

Lines changed: 178 additions & 66 deletions

File tree

src/diffusers/modular_pipelines/flux2/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,11 @@
5353
"Flux2AutoBlocks",
5454
"Flux2AutoDecodeStep",
5555
"Flux2AutoDenoiseStep",
56-
"Flux2AutoInputStep",
56+
"Flux2AutoImageInputStep",
5757
"Flux2AutoTextEncoderStep",
58+
"Flux2AutoTextInputStep",
5859
"Flux2AutoVaeEncoderStep",
5960
"Flux2BeforeDenoiseStep",
60-
"Flux2CoreDenoiseStep",
61-
"Flux2InputSequentialStep",
6261
"Flux2VaeEncoderSequentialStep",
6362
]
6463
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
@@ -102,12 +101,11 @@
102101
Flux2AutoBlocks,
103102
Flux2AutoDecodeStep,
104103
Flux2AutoDenoiseStep,
105-
Flux2AutoInputStep,
104+
Flux2AutoImageInputStep,
106105
Flux2AutoTextEncoderStep,
106+
Flux2AutoTextInputStep,
107107
Flux2AutoVaeEncoderStep,
108108
Flux2BeforeDenoiseStep,
109-
Flux2CoreDenoiseStep,
110-
Flux2InputSequentialStep,
111109
Flux2VaeEncoderSequentialStep,
112110
)
113111
from .modular_pipeline import Flux2ModularPipeline

src/diffusers/modular_pipelines/flux2/before_denoise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
366366
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
367367
),
368368
OutputParam(
369-
name="img_ids",
369+
name="latent_ids",
370370
kwargs_type="denoiser_input_fields",
371371
type_hint=torch.Tensor,
372372
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",

src/diffusers/modular_pipelines/flux2/modular_blocks.py

Lines changed: 34 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,33 @@
3535
)
3636

3737

38+
class Flux2AutoTextInputStep(AutoPipelineBlocks):
39+
block_classes = [Flux2TextInputStep]
40+
block_names = ["text_input"]
41+
block_trigger_inputs = [None]
42+
43+
@property
44+
def description(self):
45+
return (
46+
"Text input step that processes text embeddings and determines batch size.\n"
47+
" - `Flux2TextInputStep` is always used."
48+
)
49+
50+
51+
class Flux2AutoImageInputStep(AutoPipelineBlocks):
52+
block_classes = [Flux2ImageInputStep]
53+
block_names = ["image_input"]
54+
block_trigger_inputs = ["image_latents"]
55+
56+
@property
57+
def description(self):
58+
return (
59+
"Image input step that expands image latents to match batch size.\n"
60+
" - `Flux2ImageInputStep` is used when `image_latents` is provided.\n"
61+
" - Skipped when no image conditioning is used."
62+
)
63+
64+
3865
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3966

4067

@@ -147,66 +174,14 @@ def description(self):
147174
return "Decode step that decodes the denoised latents into image outputs.\n - `Flux2DecodeStep`"
148175

149176

150-
Flux2InputBlocks = InsertableDict(
151-
[
152-
("text_inputs", Flux2TextInputStep()),
153-
("image_inputs", Flux2ImageInputStep()),
154-
]
155-
)
156-
157-
158-
class Flux2InputSequentialStep(SequentialPipelineBlocks):
159-
model_name = "flux2"
160-
block_classes = Flux2InputBlocks.values()
161-
block_names = Flux2InputBlocks.keys()
162-
163-
@property
164-
def description(self):
165-
return (
166-
"Input step that prepares the inputs for the Flux2 denoising step. It:\n"
167-
" - Makes sure the text embeddings have consistent batch size.\n"
168-
" - Processes image latents if provided."
169-
)
170-
171-
172-
class Flux2AutoInputStep(AutoPipelineBlocks):
173-
block_classes = [Flux2InputSequentialStep, Flux2TextInputStep]
174-
block_names = ["img_conditioning", "text2image"]
175-
block_trigger_inputs = ["image_latents", None]
176-
177-
@property
178-
def description(self):
179-
return (
180-
"Input step that standardizes the inputs for the denoising step.\n"
181-
"This is an auto pipeline block that works for text-to-image/image-conditioned tasks.\n"
182-
" - `Flux2InputSequentialStep` is used when `image_latents` is provided.\n"
183-
" - `Flux2TextInputStep` is used when `image_latents` is not provided.\n"
184-
)
185-
186-
187-
class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
188-
model_name = "flux2"
189-
block_classes = [Flux2AutoInputStep, Flux2AutoBeforeDenoiseStep, Flux2AutoDenoiseStep]
190-
block_names = ["input", "before_denoise", "denoise"]
191-
192-
@property
193-
def description(self):
194-
return (
195-
"Core step that performs the denoising process for Flux2. \n"
196-
" - `Flux2AutoInputStep` (input) standardizes the inputs for the denoising step.\n"
197-
" - `Flux2AutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
198-
" - `Flux2AutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
199-
"This step supports text-to-image and image-conditioned tasks for Flux2:\n"
200-
" - For image-conditioned generation, you need to provide `packed_image_latents`.\n"
201-
" - For text-to-image generation, all you need to provide is prompt embeddings."
202-
)
203-
204-
205177
AUTO_BLOCKS = InsertableDict(
206178
[
207179
("text_encoder", Flux2AutoTextEncoderStep()),
180+
("text_input", Flux2AutoTextInputStep()),
208181
("image_encoder", Flux2AutoVaeEncoderStep()),
209-
("denoise", Flux2CoreDenoiseStep()),
182+
("image_input", Flux2AutoImageInputStep()),
183+
("before_denoise", Flux2AutoBeforeDenoiseStep()),
184+
("denoise", Flux2AutoDenoiseStep()),
210185
("decode", Flux2DecodeStep()),
211186
]
212187
)
@@ -230,7 +205,7 @@ def description(self):
230205
TEXT2IMAGE_BLOCKS = InsertableDict(
231206
[
232207
("text_encoder", Flux2TextEncoderStep()),
233-
("input", Flux2TextInputStep()),
208+
("text_input", Flux2TextInputStep()),
234209
("prepare_latents", Flux2PrepareLatentsStep()),
235210
("set_timesteps", Flux2SetTimestepsStep()),
236211
("prepare_rope_inputs", Flux2RoPEInputsStep()),
@@ -242,10 +217,11 @@ def description(self):
242217
IMAGE_CONDITIONED_BLOCKS = InsertableDict(
243218
[
244219
("text_encoder", Flux2TextEncoderStep()),
220+
("text_input", Flux2TextInputStep()),
245221
("preprocess_images", Flux2ProcessImagesInputStep()),
246222
("vae_encoder", Flux2VaeEncoderStep()),
247223
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
248-
("input", Flux2InputSequentialStep()),
224+
("image_input", Flux2ImageInputStep()),
249225
("prepare_latents", Flux2PrepareLatentsStep()),
250226
("set_timesteps", Flux2SetTimestepsStep()),
251227
("prepare_rope_inputs", Flux2RoPEInputsStep()),

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1586,7 +1586,6 @@ def __init__(
15861586
for name, config_spec in self._config_specs.items():
15871587
default_configs[name] = config_spec.default
15881588
self.register_to_config(**default_configs)
1589-
15901589
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
15911590

15921591
@property

tests/modular_pipelines/flux2/__init__.py

Whitespace-only changes.
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import random
17+
import tempfile
18+
import unittest
19+
20+
import numpy as np
21+
import PIL
22+
import torch
23+
24+
from diffusers.modular_pipelines import (
25+
Flux2AutoBlocks,
26+
Flux2ModularPipeline,
27+
ModularPipeline,
28+
)
29+
from diffusers.modular_pipelines.flux2 import (
30+
Flux2AutoTextEncoderStep,
31+
Flux2RemoteTextEncoderStep,
32+
Flux2TextEncoderStep,
33+
)
34+
35+
from ...testing_utils import floats_tensor, torch_device
36+
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
37+
38+
39+
class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
40+
pipeline_class = Flux2ModularPipeline
41+
pipeline_blocks_class = Flux2AutoBlocks
42+
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
43+
44+
params = frozenset(["prompt", "height", "width", "guidance_scale"])
45+
batch_params = frozenset(["prompt"])
46+
47+
def get_dummy_inputs(self, seed=0):
48+
generator = self.get_generator(seed)
49+
inputs = {
50+
"prompt": "A painting of a squirrel eating a burger",
51+
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
52+
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
53+
"text_encoder_out_layers": (1,),
54+
"generator": generator,
55+
"num_inference_steps": 2,
56+
"guidance_scale": 4.0,
57+
"height": 32,
58+
"width": 32,
59+
"output_type": "pt",
60+
}
61+
return inputs
62+
63+
def test_float16_inference(self):
64+
super().test_float16_inference(9e-2)
65+
66+
67+
class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
68+
pipeline_class = Flux2ModularPipeline
69+
pipeline_blocks_class = Flux2AutoBlocks
70+
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
71+
72+
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
73+
batch_params = frozenset(["prompt", "image"])
74+
75+
def get_dummy_inputs(self, seed=0):
76+
generator = self.get_generator(seed)
77+
inputs = {
78+
"prompt": "A painting of a squirrel eating a burger",
79+
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
80+
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
81+
"text_encoder_out_layers": (1,),
82+
"generator": generator,
83+
"num_inference_steps": 2,
84+
"guidance_scale": 4.0,
85+
"height": 32,
86+
"width": 32,
87+
"output_type": "pt",
88+
}
89+
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
90+
image = image.cpu().permute(0, 2, 3, 1)[0]
91+
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
92+
inputs["image"] = init_image
93+
94+
return inputs
95+
96+
def test_save_from_pretrained(self):
97+
pipes = []
98+
base_pipe = self.get_pipeline().to(torch_device)
99+
pipes.append(base_pipe)
100+
101+
with tempfile.TemporaryDirectory() as tmpdirname:
102+
base_pipe.save_pretrained(tmpdirname)
103+
104+
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
105+
pipe.load_components(torch_dtype=torch.float32)
106+
pipe.to(torch_device)
107+
108+
pipes.append(pipe)
109+
110+
image_slices = []
111+
for pipe in pipes:
112+
inputs = self.get_dummy_inputs()
113+
image = pipe(**inputs, output="images")
114+
115+
image_slices.append(image[0, -3:, -3:, -1].flatten())
116+
117+
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
118+
119+
def test_float16_inference(self):
120+
super().test_float16_inference(9e-2)
121+
122+
123+
class TestFlux2AutoTextEncoderStep(unittest.TestCase):
124+
def test_auto_text_encoder_block_classes(self):
125+
auto_step = Flux2AutoTextEncoderStep()
126+
127+
assert len(auto_step.block_classes) == 2
128+
assert Flux2RemoteTextEncoderStep in auto_step.block_classes
129+
assert Flux2TextEncoderStep in auto_step.block_classes
130+
131+
def test_auto_text_encoder_trigger_inputs(self):
132+
auto_step = Flux2AutoTextEncoderStep()
133+
134+
assert auto_step.block_trigger_inputs == ["remote_text_encoder", None]
135+
136+
def test_auto_text_encoder_block_names(self):
137+
auto_step = Flux2AutoTextEncoderStep()
138+
139+
assert auto_step.block_names == ["remote", "local"]

0 commit comments

Comments
 (0)