3535)
3636
3737
class Flux2AutoTextInputStep(AutoPipelineBlocks):
    """Auto block exposing `Flux2TextInputStep` as its only candidate.

    The candidate is registered under the name ``"text_input"`` with a
    ``None`` trigger; per `description`, it is always used.
    """

    # Parallel lists: candidate class, its registered name, and its trigger input.
    block_classes = [Flux2TextInputStep]
    block_names = ["text_input"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        """Return the human-readable summary of this step."""
        summary = "Text input step that processes text embeddings and determines batch size.\n "
        selection = " - `Flux2TextInputStep` is always used."
        return summary + selection
49+
50+
class Flux2AutoImageInputStep(AutoPipelineBlocks):
    """Auto block exposing `Flux2ImageInputStep` as its only candidate.

    The candidate is registered under the name ``"image_input"`` and is
    triggered by the ``"image_latents"`` input; per `description`, the step
    is skipped when no image conditioning is used.
    """

    # Parallel lists: candidate class, its registered name, and its trigger input.
    block_classes = [Flux2ImageInputStep]
    block_names = ["image_input"]
    block_trigger_inputs = ["image_latents"]

    @property
    def description(self):
        """Return the human-readable summary of this step."""
        fragments = (
            "Image input step that expands image latents to match batch size.\n ",
            " - `Flux2ImageInputStep` is used when `image_latents` is provided.\n ",
            " - Skipped when no image conditioning is used.",
        )
        return "".join(fragments)
63+
64+
# Module-level logger, obtained from `logging.get_logger` under this module's `__name__`.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
3966
4067
@@ -147,66 +174,14 @@ def description(self):
147174 return "Decode step that decodes the denoised latents into image outputs.\n - `Flux2DecodeStep`"
148175
149176
150- Flux2InputBlocks = InsertableDict (
151- [
152- ("text_inputs" , Flux2TextInputStep ()),
153- ("image_inputs" , Flux2ImageInputStep ()),
154- ]
155- )
156-
157-
158- class Flux2InputSequentialStep (SequentialPipelineBlocks ):
159- model_name = "flux2"
160- block_classes = Flux2InputBlocks .values ()
161- block_names = Flux2InputBlocks .keys ()
162-
163- @property
164- def description (self ):
165- return (
166- "Input step that prepares the inputs for the Flux2 denoising step. It:\n "
167- " - Makes sure the text embeddings have consistent batch size.\n "
168- " - Processes image latents if provided."
169- )
170-
171-
172- class Flux2AutoInputStep (AutoPipelineBlocks ):
173- block_classes = [Flux2InputSequentialStep , Flux2TextInputStep ]
174- block_names = ["img_conditioning" , "text2image" ]
175- block_trigger_inputs = ["image_latents" , None ]
176-
177- @property
178- def description (self ):
179- return (
180- "Input step that standardizes the inputs for the denoising step.\n "
181- "This is an auto pipeline block that works for text-to-image/image-conditioned tasks.\n "
182- " - `Flux2InputSequentialStep` is used when `image_latents` is provided.\n "
183- " - `Flux2TextInputStep` is used when `image_latents` is not provided.\n "
184- )
185-
186-
187- class Flux2CoreDenoiseStep (SequentialPipelineBlocks ):
188- model_name = "flux2"
189- block_classes = [Flux2AutoInputStep , Flux2AutoBeforeDenoiseStep , Flux2AutoDenoiseStep ]
190- block_names = ["input" , "before_denoise" , "denoise" ]
191-
192- @property
193- def description (self ):
194- return (
195- "Core step that performs the denoising process for Flux2. \n "
196- " - `Flux2AutoInputStep` (input) standardizes the inputs for the denoising step.\n "
197- " - `Flux2AutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n "
198- " - `Flux2AutoDenoiseStep` (denoise) iteratively denoises the latents.\n "
199- "This step supports text-to-image and image-conditioned tasks for Flux2:\n "
200- " - For image-conditioned generation, you need to provide `packed_image_latents`.\n "
201- " - For text-to-image generation, all you need to provide is prompt embeddings."
202- )
203-
204-
# Ordered table of the auto pipeline's stages, keyed by stage name.
# Entries are inserted in execution-listing order: text encoding and text
# input, image encoding and image input, then before-denoise, denoise, decode.
AUTO_BLOCKS = InsertableDict()
AUTO_BLOCKS["text_encoder"] = Flux2AutoTextEncoderStep()
AUTO_BLOCKS["text_input"] = Flux2AutoTextInputStep()
AUTO_BLOCKS["image_encoder"] = Flux2AutoVaeEncoderStep()
AUTO_BLOCKS["image_input"] = Flux2AutoImageInputStep()
AUTO_BLOCKS["before_denoise"] = Flux2AutoBeforeDenoiseStep()
AUTO_BLOCKS["denoise"] = Flux2AutoDenoiseStep()
AUTO_BLOCKS["decode"] = Flux2DecodeStep()
@@ -230,7 +205,7 @@ def description(self):
230205TEXT2IMAGE_BLOCKS = InsertableDict (
231206 [
232207 ("text_encoder" , Flux2TextEncoderStep ()),
233- ("input " , Flux2TextInputStep ()),
208+ ("text_input " , Flux2TextInputStep ()),
234209 ("prepare_latents" , Flux2PrepareLatentsStep ()),
235210 ("set_timesteps" , Flux2SetTimestepsStep ()),
236211 ("prepare_rope_inputs" , Flux2RoPEInputsStep ()),
@@ -242,10 +217,11 @@ def description(self):
242217IMAGE_CONDITIONED_BLOCKS = InsertableDict (
243218 [
244219 ("text_encoder" , Flux2TextEncoderStep ()),
220+ ("text_input" , Flux2TextInputStep ()),
245221 ("preprocess_images" , Flux2ProcessImagesInputStep ()),
246222 ("vae_encoder" , Flux2VaeEncoderStep ()),
247223 ("prepare_image_latents" , Flux2PrepareImageLatentsStep ()),
248- ("input " , Flux2InputSequentialStep ()),
224+ ("image_input " , Flux2ImageInputStep ()),
249225 ("prepare_latents" , Flux2PrepareLatentsStep ()),
250226 ("set_timesteps" , Flux2SetTimestepsStep ()),
251227 ("prepare_rope_inputs" , Flux2RoPEInputsStep ()),
0 commit comments