Skip to content

Commit b0f50c6

Browse files
committed
update
1 parent 921b959 commit b0f50c6

5 files changed

Lines changed: 33 additions & 183 deletions

File tree

src/diffusers/modular_pipelines/flux2/encoders.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -399,22 +399,22 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
399399
condition_images = block_state.condition_images
400400

401401
if condition_images is None:
402-
block_state.image_latents = None
403-
else:
404-
device = components._execution_device
405-
dtype = components.vae.dtype
406-
407-
image_latents = []
408-
for image in condition_images:
409-
image = image.to(device=device, dtype=dtype)
410-
latent = self._encode_vae_image(
411-
vae=components.vae,
412-
image=image,
413-
generator=block_state.generator,
414-
)
415-
image_latents.append(latent)
402+
return components, state
403+
404+
device = components._execution_device
405+
dtype = components.vae.dtype
406+
407+
image_latents = []
408+
for image in condition_images:
409+
image = image.to(device=device, dtype=dtype)
410+
latent = self._encode_vae_image(
411+
vae=components.vae,
412+
image=image,
413+
generator=block_state.generator,
414+
)
415+
image_latents.append(latent)
416416

417-
block_state.image_latents = image_latents
417+
block_state.image_latents = image_latents
418418

419419
self.set_block_state(state, block_state)
420420
return components, state

src/diffusers/modular_pipelines/flux2/inputs.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ class Flux2TextInputStep(ModularPipelineBlocks):
3131
@property
3232
def description(self) -> str:
3333
return (
34-
"Text input processing step that standardizes text embeddings for Flux2 pipeline.\n"
3534
"This step:\n"
3635
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
3736
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
@@ -86,55 +85,3 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
8685

8786
self.set_block_state(state, block_state)
8887
return components, state
89-
90-
91-
class Flux2ImageInputStep(ModularPipelineBlocks):
92-
model_name = "flux2"
93-
94-
@property
95-
def description(self) -> str:
96-
return (
97-
"Image input processing step that prepares image latents for Flux2 conditioning.\n"
98-
"This step expands image latents to match the batch size."
99-
)
100-
101-
@property
102-
def inputs(self) -> List[InputParam]:
103-
return [
104-
InputParam("num_images_per_prompt", default=1),
105-
InputParam("batch_size", required=True, type_hint=int),
106-
InputParam("image_latents", type_hint=torch.Tensor),
107-
InputParam("image_latent_ids", type_hint=torch.Tensor),
108-
]
109-
110-
@property
111-
def intermediate_outputs(self) -> List[OutputParam]:
112-
return [
113-
OutputParam(
114-
"image_latents",
115-
type_hint=torch.Tensor,
116-
description="Packed image latents expanded to batch size",
117-
),
118-
OutputParam(
119-
"image_latent_ids",
120-
type_hint=torch.Tensor,
121-
description="Image latent position IDs expanded to batch size",
122-
),
123-
]
124-
125-
@torch.no_grad()
126-
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
127-
block_state = self.get_block_state(state)
128-
129-
image_latents = block_state.image_latents
130-
image_latent_ids = block_state.image_latent_ids
131-
target_batch_size = block_state.batch_size * block_state.num_images_per_prompt
132-
133-
if image_latents is not None:
134-
block_state.image_latents = image_latents.repeat(target_batch_size, 1, 1)
135-
136-
if image_latent_ids is not None:
137-
block_state.image_latent_ids = image_latent_ids.repeat(target_batch_size, 1, 1)
138-
139-
self.set_block_state(state, block_state)
140-
return components, state

src/diffusers/modular_pipelines/flux2/modular_blocks.py

Lines changed: 17 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -30,38 +30,10 @@
3030
Flux2VaeEncoderStep,
3131
)
3232
from .inputs import (
33-
Flux2ImageInputStep,
3433
Flux2TextInputStep,
3534
)
3635

3736

38-
class Flux2AutoTextInputStep(AutoPipelineBlocks):
39-
block_classes = [Flux2TextInputStep]
40-
block_names = ["text_input"]
41-
block_trigger_inputs = [None]
42-
43-
@property
44-
def description(self):
45-
return (
46-
"Text input step that processes text embeddings and determines batch size.\n"
47-
" - `Flux2TextInputStep` is always used."
48-
)
49-
50-
51-
class Flux2AutoImageInputStep(AutoPipelineBlocks):
52-
block_classes = [Flux2ImageInputStep]
53-
block_names = ["image_input"]
54-
block_trigger_inputs = ["image_latents"]
55-
56-
@property
57-
def description(self):
58-
return (
59-
"Image input step that expands image latents to match batch size.\n"
60-
" - `Flux2ImageInputStep` is used when `image_latents` is provided.\n"
61-
" - Skipped when no image conditioning is used."
62-
)
63-
64-
6537
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
6638

6739

@@ -100,21 +72,6 @@ def description(self):
10072
)
10173

10274

103-
class Flux2AutoTextEncoderStep(AutoPipelineBlocks):
104-
block_classes = [Flux2RemoteTextEncoderStep, Flux2TextEncoderStep]
105-
block_names = ["remote", "local"]
106-
block_trigger_inputs = ["remote_text_encoder", None]
107-
108-
@property
109-
def description(self):
110-
return (
111-
"Text encoder step that generates text embeddings to guide the image generation.\n"
112-
"This is an auto pipeline block that selects between local and remote text encoding.\n"
113-
" - `Flux2RemoteTextEncoderStep` is used when `remote_text_encoder=True`.\n"
114-
" - `Flux2TextEncoderStep` is used otherwise (default)."
115-
)
116-
117-
11875
Flux2BeforeDenoiseBlocks = InsertableDict(
11976
[
12077
("prepare_latents", Flux2PrepareLatentsStep()),
@@ -135,53 +92,25 @@ def description(self):
13592
return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."
13693

13794

138-
class Flux2AutoBeforeDenoiseStep(AutoPipelineBlocks):
139-
model_name = "flux2"
140-
block_classes = [Flux2BeforeDenoiseStep]
141-
block_names = ["before_denoise"]
142-
block_trigger_inputs = [None]
143-
144-
@property
145-
def description(self):
146-
return (
147-
"Before denoise step that prepares the inputs for the denoise step.\n"
148-
"This is an auto pipeline block for Flux2.\n"
149-
" - `Flux2BeforeDenoiseStep` is used for both text-to-image and image-conditioned generation."
150-
)
151-
152-
153-
class Flux2AutoDenoiseStep(AutoPipelineBlocks):
154-
block_classes = [Flux2DenoiseStep]
155-
block_names = ["denoise"]
156-
block_trigger_inputs = [None]
157-
158-
@property
159-
def description(self) -> str:
160-
return (
161-
"Denoise step that iteratively denoises the latents. "
162-
"This is an auto pipeline block that works for Flux2 text-to-image and image-conditioned tasks."
163-
" - `Flux2DenoiseStep` (denoise) for text-to-image and image-conditioned tasks."
164-
)
165-
166-
167-
class Flux2AutoDecodeStep(AutoPipelineBlocks):
168-
block_classes = [Flux2DecodeStep]
169-
block_names = ["decode"]
170-
block_trigger_inputs = [None]
171-
172-
@property
173-
def description(self):
174-
return "Decode step that decodes the denoised latents into image outputs.\n - `Flux2DecodeStep`"
95+
AUTO_BLOCKS = InsertableDict(
96+
[
97+
("text_encoder", Flux2TextEncoderStep()),
98+
("text_input", Flux2TextInputStep()),
99+
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
100+
("before_denoise", Flux2BeforeDenoiseStep()),
101+
("denoise", Flux2DenoiseStep()),
102+
("decode", Flux2DecodeStep()),
103+
]
104+
)
175105

176106

177-
AUTO_BLOCKS = InsertableDict(
107+
REMOTE_AUTO_BLOCKS = InsertableDict(
178108
[
179-
("text_encoder", Flux2AutoTextEncoderStep()),
180-
("text_input", Flux2AutoTextInputStep()),
181-
("image_encoder", Flux2AutoVaeEncoderStep()),
182-
("image_input", Flux2AutoImageInputStep()),
183-
("before_denoise", Flux2AutoBeforeDenoiseStep()),
184-
("denoise", Flux2AutoDenoiseStep()),
109+
("text_encoder", Flux2RemoteTextEncoderStep()),
110+
("text_input", Flux2TextInputStep()),
111+
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
112+
("before_denoise", Flux2BeforeDenoiseStep()),
113+
("denoise", Flux2DenoiseStep()),
185114
("decode", Flux2DecodeStep()),
186115
]
187116
)
@@ -221,7 +150,6 @@ def description(self):
221150
("preprocess_images", Flux2ProcessImagesInputStep()),
222151
("vae_encoder", Flux2VaeEncoderStep()),
223152
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
224-
("image_input", Flux2ImageInputStep()),
225153
("prepare_latents", Flux2PrepareLatentsStep()),
226154
("set_timesteps", Flux2SetTimestepsStep()),
227155
("prepare_rope_inputs", Flux2RoPEInputsStep()),
@@ -234,4 +162,5 @@ def description(self):
234162
"text2image": TEXT2IMAGE_BLOCKS,
235163
"image_conditioned": IMAGE_CONDITIONED_BLOCKS,
236164
"auto": AUTO_BLOCKS,
165+
"remote": REMOTE_AUTO_BLOCKS,
237166
}

tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import random
1717
import tempfile
18-
import unittest
1918

2019
import numpy as np
2120
import PIL
@@ -26,11 +25,6 @@
2625
Flux2ModularPipeline,
2726
ModularPipeline,
2827
)
29-
from diffusers.modular_pipelines.flux2 import (
30-
Flux2AutoTextEncoderStep,
31-
Flux2RemoteTextEncoderStep,
32-
Flux2TextEncoderStep,
33-
)
3428

3529
from ...testing_utils import floats_tensor, torch_device
3630
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
@@ -114,26 +108,7 @@ def test_save_from_pretrained(self):
114108

115109
image_slices.append(image[0, -3:, -3:, -1].flatten())
116110

117-
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
111+
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-5
118112

119113
def test_float16_inference(self):
120114
super().test_float16_inference(9e-2)
121-
122-
123-
class TestFlux2AutoTextEncoderStep(unittest.TestCase):
124-
def test_auto_text_encoder_block_classes(self):
125-
auto_step = Flux2AutoTextEncoderStep()
126-
127-
assert len(auto_step.block_classes) == 2
128-
assert Flux2RemoteTextEncoderStep in auto_step.block_classes
129-
assert Flux2TextEncoderStep in auto_step.block_classes
130-
131-
def test_auto_text_encoder_trigger_inputs(self):
132-
auto_step = Flux2AutoTextEncoderStep()
133-
134-
assert auto_step.block_trigger_inputs == ["remote_text_encoder", None]
135-
136-
def test_auto_text_encoder_block_names(self):
137-
auto_step = Flux2AutoTextEncoderStep()
138-
139-
assert auto_step.block_names == ["remote", "local"]

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ def test_inference_batch_single_identical(
165165
expected_max_diff=1e-4,
166166
):
167167
pipe = self.get_pipeline().to(torch_device)
168-
169168
inputs = self.get_dummy_inputs()
170169

171170
# Reset generator in case it has been used in self.get_dummy_inputs

0 commit comments

Comments
 (0)