Skip to content

Commit 53d8a1e

Browse files
yiyixuxu, yiyi@huggingface.co, sayakpaul, asomoza
authored
[modular]support klein (#13002)
* support klein * style * copies * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com> * Update src/diffusers/modular_pipelines/flux2/encoders.py * a few fixes: unpack latents before decoder etc * style * remove guidance to its own block * style * flux2-dev work in modular setting * up * up up * add tests --------- Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
1 parent d54669a commit 53d8a1e

15 files changed

+1408
-111
lines changed

src/diffusers/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,9 @@
413413
_import_structure["modular_pipelines"].extend(
414414
[
415415
"Flux2AutoBlocks",
416+
"Flux2KleinAutoBlocks",
417+
"Flux2KleinBaseAutoBlocks",
418+
"Flux2KleinModularPipeline",
416419
"Flux2ModularPipeline",
417420
"FluxAutoBlocks",
418421
"FluxKontextAutoBlocks",
@@ -1146,6 +1149,9 @@
11461149
else:
11471150
from .modular_pipelines import (
11481151
Flux2AutoBlocks,
1152+
Flux2KleinAutoBlocks,
1153+
Flux2KleinBaseAutoBlocks,
1154+
Flux2KleinModularPipeline,
11491155
Flux2ModularPipeline,
11501156
FluxAutoBlocks,
11511157
FluxKontextAutoBlocks,

src/diffusers/modular_pipelines/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@
5454
]
5555
_import_structure["flux2"] = [
5656
"Flux2AutoBlocks",
57+
"Flux2KleinAutoBlocks",
58+
"Flux2KleinBaseAutoBlocks",
5759
"Flux2ModularPipeline",
60+
"Flux2KleinModularPipeline",
5861
]
5962
_import_structure["qwenimage"] = [
6063
"QwenImageAutoBlocks",
@@ -81,7 +84,13 @@
8184
else:
8285
from .components_manager import ComponentsManager
8386
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
84-
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
87+
from .flux2 import (
88+
Flux2AutoBlocks,
89+
Flux2KleinAutoBlocks,
90+
Flux2KleinBaseAutoBlocks,
91+
Flux2KleinModularPipeline,
92+
Flux2ModularPipeline,
93+
)
8594
from .modular_pipeline import (
8695
AutoPipelineBlocks,
8796
BlockState,

src/diffusers/modular_pipelines/flux2/__init__.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,19 @@
4343
"Flux2ProcessImagesInputStep",
4444
"Flux2TextInputStep",
4545
]
46-
_import_structure["modular_blocks"] = [
46+
_import_structure["modular_blocks_flux2"] = [
4747
"ALL_BLOCKS",
4848
"AUTO_BLOCKS",
4949
"REMOTE_AUTO_BLOCKS",
5050
"TEXT2IMAGE_BLOCKS",
5151
"IMAGE_CONDITIONED_BLOCKS",
5252
"Flux2AutoBlocks",
5353
"Flux2AutoVaeEncoderStep",
54-
"Flux2BeforeDenoiseStep",
54+
"Flux2CoreDenoiseStep",
5555
"Flux2VaeEncoderSequentialStep",
5656
]
57-
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
57+
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
58+
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
5859

5960
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
6061
try:
@@ -85,18 +86,22 @@
8586
Flux2ProcessImagesInputStep,
8687
Flux2TextInputStep,
8788
)
88-
from .modular_blocks import (
89+
from .modular_blocks_flux2 import (
8990
ALL_BLOCKS,
9091
AUTO_BLOCKS,
9192
IMAGE_CONDITIONED_BLOCKS,
9293
REMOTE_AUTO_BLOCKS,
9394
TEXT2IMAGE_BLOCKS,
9495
Flux2AutoBlocks,
9596
Flux2AutoVaeEncoderStep,
96-
Flux2BeforeDenoiseStep,
97+
Flux2CoreDenoiseStep,
9798
Flux2VaeEncoderSequentialStep,
9899
)
99-
from .modular_pipeline import Flux2ModularPipeline
100+
from .modular_blocks_flux2_klein import (
101+
Flux2KleinAutoBlocks,
102+
Flux2KleinBaseAutoBlocks,
103+
)
104+
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
100105
else:
101106
import sys
102107

src/diffusers/modular_pipelines/flux2/before_denoise.py

Lines changed: 103 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -129,17 +129,9 @@ def inputs(self) -> List[InputParam]:
129129
InputParam("num_inference_steps", default=50),
130130
InputParam("timesteps"),
131131
InputParam("sigmas"),
132-
InputParam("guidance_scale", default=4.0),
133132
InputParam("latents", type_hint=torch.Tensor),
134-
InputParam("num_images_per_prompt", default=1),
135133
InputParam("height", type_hint=int),
136134
InputParam("width", type_hint=int),
137-
InputParam(
138-
"batch_size",
139-
required=True,
140-
type_hint=int,
141-
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
142-
),
143135
]
144136

145137
@property
@@ -151,13 +143,12 @@ def intermediate_outputs(self) -> List[OutputParam]:
151143
type_hint=int,
152144
description="The number of denoising steps to perform at inference time",
153145
),
154-
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
155146
]
156147

157148
@torch.no_grad()
158149
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
159150
block_state = self.get_block_state(state)
160-
block_state.device = components._execution_device
151+
device = components._execution_device
161152

162153
scheduler = components.scheduler
163154

@@ -183,19 +174,14 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
183174
timesteps, num_inference_steps = retrieve_timesteps(
184175
scheduler,
185176
num_inference_steps,
186-
block_state.device,
177+
device,
187178
timesteps=timesteps,
188179
sigmas=sigmas,
189180
mu=mu,
190181
)
191182
block_state.timesteps = timesteps
192183
block_state.num_inference_steps = num_inference_steps
193184

194-
batch_size = block_state.batch_size * block_state.num_images_per_prompt
195-
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
196-
guidance = guidance.expand(batch_size)
197-
block_state.guidance = guidance
198-
199185
components.scheduler.set_begin_index(0)
200186

201187
self.set_block_state(state, block_state)
@@ -353,7 +339,61 @@ def description(self) -> str:
353339
def inputs(self) -> List[InputParam]:
354340
return [
355341
InputParam(name="prompt_embeds", required=True),
356-
InputParam(name="latent_ids"),
342+
]
343+
344+
@property
345+
def intermediate_outputs(self) -> List[OutputParam]:
346+
return [
347+
OutputParam(
348+
name="txt_ids",
349+
kwargs_type="denoiser_input_fields",
350+
type_hint=torch.Tensor,
351+
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
352+
),
353+
]
354+
355+
@staticmethod
356+
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
357+
"""Prepare 4D position IDs for text tokens."""
358+
B, L, _ = x.shape
359+
out_ids = []
360+
361+
for i in range(B):
362+
t = torch.arange(1) if t_coord is None else t_coord[i]
363+
h = torch.arange(1)
364+
w = torch.arange(1)
365+
seq_l = torch.arange(L)
366+
367+
coords = torch.cartesian_prod(t, h, w, seq_l)
368+
out_ids.append(coords)
369+
370+
return torch.stack(out_ids)
371+
372+
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
373+
block_state = self.get_block_state(state)
374+
375+
prompt_embeds = block_state.prompt_embeds
376+
device = prompt_embeds.device
377+
378+
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
379+
block_state.txt_ids = block_state.txt_ids.to(device)
380+
381+
self.set_block_state(state, block_state)
382+
return components, state
383+
384+
385+
class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
386+
model_name = "flux2-klein"
387+
388+
@property
389+
def description(self) -> str:
390+
return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
391+
392+
@property
393+
def inputs(self) -> List[InputParam]:
394+
return [
395+
InputParam(name="prompt_embeds", required=True),
396+
InputParam(name="negative_prompt_embeds", required=False),
357397
]
358398

359399
@property
@@ -366,10 +406,10 @@ def intermediate_outputs(self) -> List[OutputParam]:
366406
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
367407
),
368408
OutputParam(
369-
name="latent_ids",
409+
name="negative_txt_ids",
370410
kwargs_type="denoiser_input_fields",
371411
type_hint=torch.Tensor,
372-
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
412+
description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
373413
),
374414
]
375415

@@ -399,6 +439,11 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
399439
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
400440
block_state.txt_ids = block_state.txt_ids.to(device)
401441

442+
block_state.negative_txt_ids = None
443+
if block_state.negative_prompt_embeds is not None:
444+
block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
445+
block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)
446+
402447
self.set_block_state(state, block_state)
403448
return components, state
404449

@@ -506,3 +551,42 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
506551

507552
self.set_block_state(state, block_state)
508553
return components, state
554+
555+
556+
class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
557+
model_name = "flux2"
558+
559+
@property
560+
def description(self) -> str:
561+
return "Step that prepares the guidance scale tensor for Flux2 inference"
562+
563+
@property
564+
def inputs(self) -> List[InputParam]:
565+
return [
566+
InputParam("guidance_scale", default=4.0),
567+
InputParam("num_images_per_prompt", default=1),
568+
InputParam(
569+
"batch_size",
570+
required=True,
571+
type_hint=int,
572+
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
573+
),
574+
]
575+
576+
@property
577+
def intermediate_outputs(self) -> List[OutputParam]:
578+
return [
579+
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
580+
]
581+
582+
@torch.no_grad()
583+
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
584+
block_state = self.get_block_state(state)
585+
device = components._execution_device
586+
batch_size = block_state.batch_size * block_state.num_images_per_prompt
587+
guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
588+
guidance = guidance.expand(batch_size)
589+
block_state.guidance = guidance
590+
591+
self.set_block_state(state, block_state)
592+
return components, state

0 commit comments

Comments
 (0)