Skip to content

Commit 3806a9a

Browse files
committed
update
1 parent 7587674 commit 3806a9a

3 files changed

Lines changed: 81 additions & 171 deletions

File tree

src/diffusers/modular_pipelines/flux2/encoders.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
import torch
1818
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
1919

20-
from ...configuration_utils import FrozenDict
2120
from ...models import AutoencoderKLFlux2
22-
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
2321
from ...utils import logging
2422
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
2523
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -272,75 +270,6 @@ def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> Pi
272270
return components, state
273271

274272

275-
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
276-
model_name = "flux2"
277-
278-
@property
279-
def description(self) -> str:
280-
return "Image preprocess step for Flux2. Validates and preprocesses reference images."
281-
282-
@property
283-
def expected_components(self) -> List[ComponentSpec]:
284-
return [
285-
ComponentSpec(
286-
"image_processor",
287-
Flux2ImageProcessor,
288-
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
289-
default_creation_method="from_config",
290-
),
291-
]
292-
293-
@property
294-
def inputs(self) -> List[InputParam]:
295-
return [
296-
InputParam("image"),
297-
InputParam("height"),
298-
InputParam("width"),
299-
]
300-
301-
@property
302-
def intermediate_outputs(self) -> List[OutputParam]:
303-
return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]
304-
305-
@torch.no_grad()
306-
def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
307-
block_state = self.get_block_state(state)
308-
images = block_state.image
309-
310-
if images is None:
311-
block_state.condition_images = None
312-
else:
313-
if not isinstance(images, list):
314-
images = [images]
315-
316-
condition_images = []
317-
for img in images:
318-
components.image_processor.check_image_input(img)
319-
320-
image_width, image_height = img.size
321-
if image_width * image_height > 1024 * 1024:
322-
img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
323-
image_width, image_height = img.size
324-
325-
multiple_of = components.vae_scale_factor * 2
326-
image_width = (image_width // multiple_of) * multiple_of
327-
image_height = (image_height // multiple_of) * multiple_of
328-
condition_img = components.image_processor.preprocess(
329-
img, height=image_height, width=image_width, resize_mode="crop"
330-
)
331-
condition_images.append(condition_img)
332-
333-
if block_state.height is None:
334-
block_state.height = image_height
335-
if block_state.width is None:
336-
block_state.width = image_width
337-
338-
block_state.condition_images = condition_images
339-
340-
self.set_block_state(state, block_state)
341-
return components, state
342-
343-
344273
class Flux2VaeEncoderStep(ModularPipelineBlocks):
345274
model_name = "flux2"
346275

Lines changed: 76 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,79 @@
1-
# copyright 2025 the huggingface team. all rights reserved.
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
22
#
3-
# licensed under the apache license, version 2.0 (the "license");
4-
# you may not use this file except in compliance with the license.
5-
# you may obtain a copy of the license at
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
66
#
7-
# http://www.apache.org/licenses/license-2.0
7+
# http://www.apache.org/licenses/LICENSE-2.0
88
#
9-
# unless required by applicable law or agreed to in writing, software
10-
# distributed under the license is distributed on an "as is" basis,
11-
# without warranties or conditions of any kind, either express or implied.
12-
# see the license for the specific language governing permissions and
13-
# limitations under the license.
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
1414

15-
from typing import list
15+
from typing import List
1616

1717
import torch
1818

19-
from ...configuration_utils import frozendict
20-
from ...pipelines.flux2.image_processor import flux2imageprocessor
19+
from ...configuration_utils import FrozenDict
20+
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
2121
from ...utils import logging
22-
from ..modular_pipeline import modularpipelineblocks, pipelinestate
23-
from ..modular_pipeline_utils import componentspec, inputparam, outputparam
24-
from .modular_pipeline import flux2modularpipeline
22+
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
23+
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
24+
from .modular_pipeline import Flux2ModularPipeline
2525

2626

2727
logger = logging.get_logger(__name__)
2828

2929

30-
class flux2textinputstep(modularpipelineblocks):
30+
class Flux2TextInputStep(ModularPipelineBlocks):
3131
model_name = "flux2"
3232

3333
@property
3434
def description(self) -> str:
3535
return (
36-
"this step:\n"
37-
" 1. determines `batch_size` and `dtype` based on `prompt_embeds`\n"
38-
" 2. ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
36+
"This step:\n"
37+
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
38+
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
3939
)
4040

4141
@property
42-
def inputs(self) -> list[inputparam]:
42+
def inputs(self) -> List[InputParam]:
4343
return [
44-
inputparam("num_images_per_prompt", default=1),
45-
inputparam(
44+
InputParam("num_images_per_prompt", default=1),
45+
InputParam(
4646
"prompt_embeds",
47-
required=true,
47+
required=True,
4848
kwargs_type="denoiser_input_fields",
49-
type_hint=torch.tensor,
50-
description="pre-generated text embeddings from mistral3. can be generated from text_encoder step.",
49+
type_hint=torch.Tensor,
50+
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
5151
),
5252
]
5353

5454
@property
55-
def intermediate_outputs(self) -> list[str]:
55+
def intermediate_outputs(self) -> List[str]:
5656
return [
57-
outputparam(
57+
OutputParam(
5858
"batch_size",
5959
type_hint=int,
60-
description="number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
60+
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
6161
),
62-
outputparam(
62+
OutputParam(
6363
"dtype",
6464
type_hint=torch.dtype,
65-
description="data type of model tensor inputs (determined by `prompt_embeds`)",
65+
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
6666
),
67-
outputparam(
67+
OutputParam(
6868
"prompt_embeds",
69-
type_hint=torch.tensor,
69+
type_hint=torch.Tensor,
7070
kwargs_type="denoiser_input_fields",
71-
description="text embeddings used to guide the image generation",
71+
description="Text embeddings used to guide the image generation",
7272
),
7373
]
7474

7575
@torch.no_grad()
76-
def __call__(self, components: flux2modularpipeline, state: pipelinestate) -> pipelinestate:
76+
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
7777
block_state = self.get_block_state(state)
7878

7979
block_state.batch_size = block_state.prompt_embeds.shape[0]
@@ -89,70 +89,72 @@ def __call__(self, components: flux2modularpipeline, state: pipelinestate) -> pi
8989
return components, state
9090

9191

92-
class flux2processimagesinputstep(modularpipelineblocks):
92+
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
9393
model_name = "flux2"
9494

9595
@property
9696
def description(self) -> str:
97-
return "image preprocess step for flux2. validates and preprocesses reference images."
97+
return "Image preprocess step for Flux2. Validates and preprocesses reference images."
9898

9999
@property
100-
def expected_components(self) -> list[componentspec]:
100+
def expected_components(self) -> List[ComponentSpec]:
101101
return [
102-
componentspec(
102+
ComponentSpec(
103103
"image_processor",
104-
flux2imageprocessor,
105-
config=frozendict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
104+
Flux2ImageProcessor,
105+
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
106106
default_creation_method="from_config",
107107
),
108108
]
109109

110110
@property
111-
def inputs(self) -> list[inputparam]:
111+
def inputs(self) -> List[InputParam]:
112112
return [
113-
inputparam("image"),
114-
inputparam("height"),
115-
inputparam("width"),
113+
InputParam("image"),
114+
InputParam("height"),
115+
InputParam("width"),
116116
]
117117

118118
@property
119-
def intermediate_outputs(self) -> list[outputparam]:
120-
return [outputparam(name="condition_images", type_hint=list[torch.tensor])]
119+
def intermediate_outputs(self) -> List[OutputParam]:
120+
return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]
121121

122122
@torch.no_grad()
123-
def __call__(self, components: flux2modularpipeline, state: pipelinestate):
123+
def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
124124
block_state = self.get_block_state(state)
125125
images = block_state.image
126126

127-
if images is none:
128-
block_state.condition_images = none
129-
else:
130-
if not isinstance(images, list):
131-
images = [images]
127+
if images is None:
128+
block_state.condition_images = None
129+
self.set_block_state(state, block_state)
130+
return components, state
132131

133-
condition_images = []
134-
for img in images:
135-
components.image_processor.check_image_input(img)
132+
if not isinstance(images, list):
133+
images = [images]
136134

135+
condition_images = []
136+
for img in images:
137+
components.image_processor.check_image_input(img)
138+
139+
image_width, image_height = img.size
140+
if image_width * image_height > 1024 * 1024:
141+
img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
137142
image_width, image_height = img.size
138-
if image_width * image_height > 1024 * 1024:
139-
img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
140-
image_width, image_height = img.size
141-
142-
multiple_of = components.vae_scale_factor * 2
143-
image_width = (image_width // multiple_of) * multiple_of
144-
image_height = (image_height // multiple_of) * multiple_of
145-
condition_img = components.image_processor.preprocess(
146-
img, height=image_height, width=image_width, resize_mode="crop"
147-
)
148-
condition_images.append(condition_img)
149-
150-
if block_state.height is none:
151-
block_state.height = image_height
152-
if block_state.width is none:
153-
block_state.width = image_width
154-
155-
block_state.condition_images = condition_images
143+
144+
multiple_of = components.vae_scale_factor * 2
145+
image_width = (image_width // multiple_of) * multiple_of
146+
image_height = (image_height // multiple_of) * multiple_of
147+
condition_img = components.image_processor.preprocess(
148+
img, height=image_height, width=image_width, resize_mode="crop"
149+
)
150+
condition_images.append(condition_img)
151+
152+
if block_state.height is None:
153+
block_state.height = image_height
154+
if block_state.width is None:
155+
block_state.width = image_width
156+
157+
block_state.condition_images = condition_images
156158

157159
self.set_block_state(state, block_state)
158160
return components, state

tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,14 @@
1414
# limitations under the License.
1515

1616
import random
17-
import tempfile
1817

1918
import numpy as np
2019
import PIL
21-
import torch
20+
import pytest
2221

2322
from diffusers.modular_pipelines import (
2423
Flux2AutoBlocks,
2524
Flux2ModularPipeline,
26-
ModularPipeline,
2725
)
2826

2927
from ...testing_utils import floats_tensor, torch_device
@@ -87,28 +85,9 @@ def get_dummy_inputs(self, seed=0):
8785

8886
return inputs
8987

90-
def test_save_from_pretrained(self):
91-
pipes = []
92-
base_pipe = self.get_pipeline().to(torch_device)
93-
pipes.append(base_pipe)
94-
95-
with tempfile.TemporaryDirectory() as tmpdirname:
96-
base_pipe.save_pretrained(tmpdirname)
97-
98-
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
99-
pipe.load_components(torch_dtype=torch.float32)
100-
pipe.to(torch_device)
101-
102-
pipes.append(pipe)
103-
104-
image_slices = []
105-
for pipe in pipes:
106-
inputs = self.get_dummy_inputs()
107-
image = pipe(**inputs, output="images")
108-
109-
image_slices.append(image[0, -3:, -3:, -1].flatten())
110-
111-
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-5
112-
11388
def test_float16_inference(self):
11489
super().test_float16_inference(9e-2)
90+
91+
@pytest.mark.skip(reason="batched inference is currently not supported")
92+
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
93+
return

0 commit comments

Comments
 (0)