# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
1414
15- from typing import list
15+ from typing import List
1616
1717import torch
1818
19- from ...configuration_utils import frozendict
20- from ...pipelines .flux2 .image_processor import flux2imageprocessor
19+ from ...configuration_utils import FrozenDict
20+ from ...pipelines .flux2 .image_processor import Flux2ImageProcessor
2121from ...utils import logging
22- from ..modular_pipeline import modularpipelineblocks , pipelinestate
23- from ..modular_pipeline_utils import componentspec , inputparam , outputparam
24- from .modular_pipeline import flux2modularpipeline
22+ from ..modular_pipeline import ModularPipelineBlocks , PipelineState
23+ from ..modular_pipeline_utils import ComponentSpec , InputParam , OutputParam
24+ from .modular_pipeline import Flux2ModularPipeline
2525
2626
2727logger = logging .get_logger (__name__ )
2828
2929
30- class flux2textinputstep ( modularpipelineblocks ):
30+ class Flux2TextInputStep ( ModularPipelineBlocks ):
3131 model_name = "flux2"
3232
@property
def description(self) -> str:
    """Return a human-readable summary of what this step does."""
    # Pieces are kept separate so each numbered item stays on its own source line;
    # joined, they form one multi-line string.
    summary_pieces = (
        "This step:\n ",
        " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n ",
        " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)",
    )
    return "".join(summary_pieces)
4040
@property
def inputs(self) -> List[InputParam]:
    """Declare the inputs this step reads.

    Returns:
        A list with two `InputParam` entries:
        - `num_images_per_prompt`, defaulting to 1.
        - `prompt_embeds`, a required `torch.Tensor` tagged with the
          `denoiser_input_fields` kwargs type.
    """
    num_images_param = InputParam("num_images_per_prompt", default=1)
    prompt_embeds_param = InputParam(
        "prompt_embeds",
        required=True,
        kwargs_type="denoiser_input_fields",
        type_hint=torch.Tensor,
        description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
    )
    return [num_images_param, prompt_embeds_param]
5353
@property
def intermediate_outputs(self) -> List[OutputParam]:
    """Declare the values this step writes for later steps.

    Returns:
        A list of `OutputParam` entries for `batch_size` (int), `dtype`
        (`torch.dtype`) and `prompt_embeds` (`torch.Tensor`, tagged with the
        `denoiser_input_fields` kwargs type).
    """
    # The annotation used to be `List[str]`, but every element is an `OutputParam`.
    return [
        OutputParam(
            "batch_size",
            type_hint=int,
            description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
        ),
        OutputParam(
            "dtype",
            type_hint=torch.dtype,
            description="Data type of model tensor inputs (determined by `prompt_embeds`)",
        ),
        OutputParam(
            "prompt_embeds",
            type_hint=torch.Tensor,
            kwargs_type="denoiser_input_fields",
            description="Text embeddings used to guide the image generation",
        ),
    ]
7474
7575 @torch .no_grad ()
76- def __call__ (self , components : flux2modularpipeline , state : pipelinestate ) -> pipelinestate :
76+ def __call__ (self , components : Flux2ModularPipeline , state : PipelineState ) -> PipelineState :
7777 block_state = self .get_block_state (state )
7878
7979 block_state .batch_size = block_state .prompt_embeds .shape [0 ]
@@ -89,70 +89,72 @@ def __call__(self, components: flux2modularpipeline, state: pipelinestate) -> pi
8989 return components , state
9090
9191
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
    """Pipeline block that validates and preprocesses Flux2 reference images.

    Reads `image`, `height` and `width` from the block state and writes
    `condition_images`. When `height` / `width` are `None`, they are filled in
    from the first preprocessed image.
    """

    model_name = "flux2"

    # Pixel-area threshold above which a reference image is passed through
    # `image_processor._resize_to_target_area` before preprocessing. Kept as a single
    # class-level constant instead of repeating `1024 * 1024` inline.
    _max_condition_image_area = 1024 * 1024

    @property
    def description(self) -> str:
        """Return a one-line summary of this step."""
        return "Image preprocess step for Flux2. Validates and preprocesses reference images."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Declare the `image_processor` component, created from a fixed config by default."""
        return [
            ComponentSpec(
                "image_processor",
                Flux2ImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Declare the inputs read by this step: `image`, `height` and `width`."""
        return [
            InputParam("image"),
            InputParam("height"),
            InputParam("width"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declare the single output written by this step: `condition_images`."""
        return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]

    def _preprocess_condition_image(self, components, img):
        """Validate and preprocess one reference image.

        Returns:
            A `(condition_img, image_width, image_height)` tuple, where the width and
            height are the floored dimensions passed to `image_processor.preprocess`.
        """
        image_processor = components.image_processor
        image_processor.check_image_input(img)

        # NOTE(review): assumes `img.size` is a `(width, height)` pair (e.g. a PIL image) — confirm.
        image_width, image_height = img.size
        max_area = self._max_condition_image_area
        if image_width * image_height > max_area:
            img = image_processor._resize_to_target_area(img, max_area)
            image_width, image_height = img.size

        # Floor both dimensions to a multiple of `vae_scale_factor * 2`.
        multiple_of = components.vae_scale_factor * 2
        image_width = (image_width // multiple_of) * multiple_of
        image_height = (image_height // multiple_of) * multiple_of

        condition_img = image_processor.preprocess(
            img, height=image_height, width=image_width, resize_mode="crop"
        )
        return condition_img, image_width, image_height

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
        """Preprocess the reference images found in `state`.

        Args:
            components: Pipeline object providing `image_processor` and `vae_scale_factor`.
            state: Pipeline state holding `image`, `height` and `width`.

        Returns:
            The `(components, state)` pair, with `condition_images` stored in the state
            (`None` when no image was supplied).
        """
        block_state = self.get_block_state(state)
        images = block_state.image

        # No reference image: record that explicitly and leave height/width untouched.
        if images is None:
            block_state.condition_images = None
            self.set_block_state(state, block_state)
            return components, state

        # Accept a single image as well as a list of images.
        if not isinstance(images, list):
            images = [images]

        condition_images = []
        for img in images:
            condition_img, image_width, image_height = self._preprocess_condition_image(components, img)
            condition_images.append(condition_img)

            # Only fill in dimensions the caller did not provide; once set by the first
            # image, later images no longer change them.
            if block_state.height is None:
                block_state.height = image_height
            if block_state.width is None:
                block_state.width = image_width

        block_state.condition_images = condition_images

        self.set_block_state(state, block_state)
        return components, state
0 commit comments