@@ -129,17 +129,9 @@ def inputs(self) -> List[InputParam]:
129129 InputParam ("num_inference_steps" , default = 50 ),
130130 InputParam ("timesteps" ),
131131 InputParam ("sigmas" ),
132- InputParam ("guidance_scale" , default = 4.0 ),
133132 InputParam ("latents" , type_hint = torch .Tensor ),
134- InputParam ("num_images_per_prompt" , default = 1 ),
135133 InputParam ("height" , type_hint = int ),
136134 InputParam ("width" , type_hint = int ),
137- InputParam (
138- "batch_size" ,
139- required = True ,
140- type_hint = int ,
141- description = "Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`." ,
142- ),
143135 ]
144136
145137 @property
@@ -151,13 +143,12 @@ def intermediate_outputs(self) -> List[OutputParam]:
151143 type_hint = int ,
152144 description = "The number of denoising steps to perform at inference time" ,
153145 ),
154- OutputParam ("guidance" , type_hint = torch .Tensor , description = "Guidance scale tensor" ),
155146 ]
156147
157148 @torch .no_grad ()
158149 def __call__ (self , components : Flux2ModularPipeline , state : PipelineState ) -> PipelineState :
159150 block_state = self .get_block_state (state )
160- block_state . device = components ._execution_device
151+ device = components ._execution_device
161152
162153 scheduler = components .scheduler
163154
@@ -183,19 +174,14 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
183174 timesteps , num_inference_steps = retrieve_timesteps (
184175 scheduler ,
185176 num_inference_steps ,
186- block_state . device ,
177+ device ,
187178 timesteps = timesteps ,
188179 sigmas = sigmas ,
189180 mu = mu ,
190181 )
191182 block_state .timesteps = timesteps
192183 block_state .num_inference_steps = num_inference_steps
193184
194- batch_size = block_state .batch_size * block_state .num_images_per_prompt
195- guidance = torch .full ([1 ], block_state .guidance_scale , device = block_state .device , dtype = torch .float32 )
196- guidance = guidance .expand (batch_size )
197- block_state .guidance = guidance
198-
199185 components .scheduler .set_begin_index (0 )
200186
201187 self .set_block_state (state , block_state )
@@ -353,7 +339,61 @@ def description(self) -> str:
353339 def inputs (self ) -> List [InputParam ]:
354340 return [
355341 InputParam (name = "prompt_embeds" , required = True ),
356- InputParam (name = "latent_ids" ),
342+ ]
343+
344+ @property
345+ def intermediate_outputs (self ) -> List [OutputParam ]:
346+ return [
347+ OutputParam (
348+ name = "txt_ids" ,
349+ kwargs_type = "denoiser_input_fields" ,
350+ type_hint = torch .Tensor ,
351+ description = "4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation." ,
352+ ),
353+ ]
354+
355+ @staticmethod
356+ def _prepare_text_ids (x : torch .Tensor , t_coord : Optional [torch .Tensor ] = None ):
357+ """Prepare 4D position IDs for text tokens."""
358+ B , L , _ = x .shape
359+ out_ids = []
360+
361+ for i in range (B ):
362+ t = torch .arange (1 ) if t_coord is None else t_coord [i ]
363+ h = torch .arange (1 )
364+ w = torch .arange (1 )
365+ seq_l = torch .arange (L )
366+
367+ coords = torch .cartesian_prod (t , h , w , seq_l )
368+ out_ids .append (coords )
369+
370+ return torch .stack (out_ids )
371+
372+ def __call__ (self , components : Flux2ModularPipeline , state : PipelineState ) -> PipelineState :
373+ block_state = self .get_block_state (state )
374+
375+ prompt_embeds = block_state .prompt_embeds
376+ device = prompt_embeds .device
377+
378+ block_state .txt_ids = self ._prepare_text_ids (prompt_embeds )
379+ block_state .txt_ids = block_state .txt_ids .to (device )
380+
381+ self .set_block_state (state , block_state )
382+ return components , state
383+
384+
385+ class Flux2KleinBaseRoPEInputsStep (ModularPipelineBlocks ):
386+ model_name = "flux2-klein"
387+
388+ @property
389+ def description (self ) -> str :
390+ return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
391+
392+ @property
393+ def inputs (self ) -> List [InputParam ]:
394+ return [
395+ InputParam (name = "prompt_embeds" , required = True ),
396+ InputParam (name = "negative_prompt_embeds" , required = False ),
357397 ]
358398
359399 @property
@@ -366,10 +406,10 @@ def intermediate_outputs(self) -> List[OutputParam]:
366406 description = "4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation." ,
367407 ),
368408 OutputParam (
369- name = "latent_ids " ,
409+ name = "negative_txt_ids " ,
370410 kwargs_type = "denoiser_input_fields" ,
371411 type_hint = torch .Tensor ,
372- description = "4D position IDs (T, H, W, L) for image latents , used for RoPE calculation." ,
412+ description = "4D position IDs (T, H, W, L) for negative text tokens , used for RoPE calculation." ,
373413 ),
374414 ]
375415
@@ -399,6 +439,11 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
399439 block_state .txt_ids = self ._prepare_text_ids (prompt_embeds )
400440 block_state .txt_ids = block_state .txt_ids .to (device )
401441
442+ block_state .negative_txt_ids = None
443+ if block_state .negative_prompt_embeds is not None :
444+ block_state .negative_txt_ids = self ._prepare_text_ids (block_state .negative_prompt_embeds )
445+ block_state .negative_txt_ids = block_state .negative_txt_ids .to (device )
446+
402447 self .set_block_state (state , block_state )
403448 return components , state
404449
@@ -506,3 +551,42 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
506551
507552 self .set_block_state (state , block_state )
508553 return components , state
554+
555+
556+ class Flux2PrepareGuidanceStep (ModularPipelineBlocks ):
557+ model_name = "flux2"
558+
559+ @property
560+ def description (self ) -> str :
561+ return "Step that prepares the guidance scale tensor for Flux2 inference"
562+
563+ @property
564+ def inputs (self ) -> List [InputParam ]:
565+ return [
566+ InputParam ("guidance_scale" , default = 4.0 ),
567+ InputParam ("num_images_per_prompt" , default = 1 ),
568+ InputParam (
569+ "batch_size" ,
570+ required = True ,
571+ type_hint = int ,
572+ description = "Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`." ,
573+ ),
574+ ]
575+
576+ @property
577+ def intermediate_outputs (self ) -> List [OutputParam ]:
578+ return [
579+ OutputParam ("guidance" , type_hint = torch .Tensor , description = "Guidance scale tensor" ),
580+ ]
581+
582+ @torch .no_grad ()
583+ def __call__ (self , components : Flux2ModularPipeline , state : PipelineState ) -> PipelineState :
584+ block_state = self .get_block_state (state )
585+ device = components ._execution_device
586+ batch_size = block_state .batch_size * block_state .num_images_per_prompt
587+ guidance = torch .full ([1 ], block_state .guidance_scale , device = device , dtype = torch .float32 )
588+ guidance = guidance .expand (batch_size )
589+ block_state .guidance = guidance
590+
591+ self .set_block_state (state , block_state )
592+ return components , state
0 commit comments