Skip to content

Commit 44a54e7

Browse files
committed
Fix Flux pipeline validation and modular helpers
1 parent 48f39c2 commit 44a54e7

22 files changed

Lines changed: 607 additions & 71 deletions

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 94 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch.nn.functional as F
2222

2323
from ...configuration_utils import ConfigMixin, register_to_config
24+
from ...image_processor import IPAdapterMaskProcessor
2425
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
2526
from ...utils import apply_lora_scale, logging
2627
from ...utils.torch_utils import maybe_allow_in_graph
@@ -244,28 +245,100 @@ def __call__(
244245
# IP-adapter
245246
ip_attn_output = torch.zeros_like(hidden_states)
246247

247-
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
248-
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
248+
if ip_adapter_masks is not None:
249+
if not isinstance(ip_adapter_masks, list):
250+
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
251+
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
252+
raise ValueError(
253+
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
254+
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
255+
f"({len(ip_hidden_states)})"
256+
)
257+
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
258+
if mask is None:
259+
continue
260+
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
261+
raise ValueError(
262+
"Each element of the ip_adapter_masks array should be a tensor with shape "
263+
"[1, num_images_for_ip_adapter, height, width]."
264+
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
265+
)
266+
num_ip_images = 1 if ip_state.ndim == 3 else ip_state.shape[1]
267+
if mask.shape[1] != num_ip_images:
268+
raise ValueError(
269+
f"Number of masks ({mask.shape[1]}) does not match "
270+
f"number of ip images ({num_ip_images}) at index {index}"
271+
)
272+
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
273+
raise ValueError(
274+
f"Number of masks ({mask.shape[1]}) does not match "
275+
f"number of scales ({len(scale)}) at index {index}"
276+
)
277+
else:
278+
ip_adapter_masks = [None] * len(self.scale)
279+
280+
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
281+
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
249282
):
250-
ip_key = to_k_ip(current_ip_hidden_states)
251-
ip_value = to_v_ip(current_ip_hidden_states)
252-
253-
ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
254-
ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
255-
256-
current_ip_hidden_states = dispatch_attention_fn(
257-
ip_query,
258-
ip_key,
259-
ip_value,
260-
attn_mask=None,
261-
dropout_p=0.0,
262-
is_causal=False,
263-
backend=self._attention_backend,
264-
parallel_config=self._parallel_config,
265-
)
266-
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
267-
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
268-
ip_attn_output += scale * current_ip_hidden_states
283+
if mask is not None:
284+
if current_ip_hidden_states.ndim == 3:
285+
current_ip_hidden_states = current_ip_hidden_states[:, None, :, :]
286+
if not isinstance(scale, list):
287+
scale = [scale] * mask.shape[1]
288+
289+
for i in range(mask.shape[1]):
290+
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
291+
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
292+
293+
ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
294+
ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
295+
296+
_current_ip_hidden_states = dispatch_attention_fn(
297+
ip_query,
298+
ip_key,
299+
ip_value,
300+
attn_mask=None,
301+
dropout_p=0.0,
302+
is_causal=False,
303+
backend=self._attention_backend,
304+
parallel_config=self._parallel_config,
305+
)
306+
_current_ip_hidden_states = _current_ip_hidden_states.reshape(
307+
batch_size, -1, attn.heads * attn.head_dim
308+
)
309+
_current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype)
310+
311+
mask_downsample = IPAdapterMaskProcessor.downsample(
312+
mask[:, i, :, :],
313+
batch_size,
314+
_current_ip_hidden_states.shape[1],
315+
_current_ip_hidden_states.shape[2],
316+
)
317+
mask_downsample = mask_downsample.to(dtype=ip_query.dtype, device=ip_query.device)
318+
319+
ip_attn_output += scale[i] * (_current_ip_hidden_states * mask_downsample)
320+
else:
321+
ip_key = to_k_ip(current_ip_hidden_states)
322+
ip_value = to_v_ip(current_ip_hidden_states)
323+
324+
ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
325+
ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
326+
327+
current_ip_hidden_states = dispatch_attention_fn(
328+
ip_query,
329+
ip_key,
330+
ip_value,
331+
attn_mask=None,
332+
dropout_p=0.0,
333+
is_causal=False,
334+
backend=self._attention_backend,
335+
parallel_config=self._parallel_config,
336+
)
337+
current_ip_hidden_states = current_ip_hidden_states.reshape(
338+
batch_size, -1, attn.heads * attn.head_dim
339+
)
340+
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
341+
ip_attn_output += scale * current_ip_hidden_states
269342

270343
return hidden_states, encoder_hidden_states, ip_attn_output
271344
else:

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
import numpy as np
1818
import torch
1919

20-
from ...pipelines import FluxPipeline
2120
from ...schedulers import FlowMatchEulerDiscreteScheduler
2221
from ...utils import logging
2322
from ...utils.torch_utils import randn_tensor
2423
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
2524
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
2625
from .modular_pipeline import FluxModularPipeline
26+
from .pipeline_helpers import pack_latents, prepare_latent_image_ids
2727

2828

2929
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -390,7 +390,7 @@ def prepare_latents(
390390

391391
# TODO: move packing latents code to a patchifier similar to Qwen
392392
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
393-
latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)
393+
latents = pack_latents(latents, batch_size, num_channels_latents, height, width)
394394

395395
return latents
396396

@@ -470,7 +470,7 @@ def intermediate_outputs(self) -> list[OutputParam]:
470470
def check_inputs(image_latents, latents):
471471
if image_latents.shape[0] != latents.shape[0]:
472472
raise ValueError(
473-
f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
473+
f"`image_latents` must have the same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
474474
)
475475

476476
if image_latents.ndim != 3:
@@ -541,7 +541,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
541541

542542
height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
543543
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
544-
block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
544+
block_state.img_ids = prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
545545

546546
self.set_block_state(state, block_state)
547547

@@ -598,15 +598,13 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
598598
):
599599
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
600600
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
601-
img_ids = FluxPipeline._prepare_latent_image_ids(
602-
None, image_latent_height // 2, image_latent_width // 2, device, dtype
603-
)
601+
img_ids = prepare_latent_image_ids(None, image_latent_height // 2, image_latent_width // 2, device, dtype)
604602
# image ids are the same as latent ids with the first dimension set to 1 instead of 0
605603
img_ids[..., 0] = 1
606604

607605
height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
608606
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
609-
latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
607+
latent_ids = prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
610608

611609
if img_ids is not None:
612610
latent_ids = torch.cat([latent_ids, img_ids], dim=0)

src/diffusers/modular_pipelines/flux/decoders.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,12 @@
2424
from ...video_processor import VaeImageProcessor
2525
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
2626
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
27+
from .pipeline_helpers import unpack_latents
2728

2829

2930
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3031

3132

32-
def _unpack_latents(latents, height, width, vae_scale_factor):
33-
batch_size, num_patches, channels = latents.shape
34-
35-
# VAE applies 8x compression on images but we must also account for packing which requires
36-
# latent height and width to be divisible by 2.
37-
height = 2 * (int(height) // (vae_scale_factor * 2))
38-
width = 2 * (int(width) // (vae_scale_factor * 2))
39-
40-
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
41-
latents = latents.permute(0, 3, 1, 4, 2, 5)
42-
43-
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
44-
45-
return latents
46-
47-
4833
class FluxDecodeStep(ModularPipelineBlocks):
4934
model_name = "flux"
5035

@@ -95,7 +80,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
9580

9681
if not block_state.output_type == "latent":
9782
latents = block_state.latents
98-
latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor)
83+
latents = unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor)
9984
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
10085
block_state.images = vae.decode(latents, return_dict=False)[0]
10186
block_state.images = components.image_processor.postprocess(

src/diffusers/modular_pipelines/flux/encoders.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
2727
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
2828
from .modular_pipeline import FluxModularPipeline
29+
from .pipeline_helpers import PREFERRED_KONTEXT_RESOLUTIONS
2930

3031

3132
if is_ftfy_available():
@@ -170,8 +171,6 @@ def intermediate_outputs(self) -> list[OutputParam]:
170171

171172
@torch.no_grad()
172173
def __call__(self, components: FluxModularPipeline, state: PipelineState):
173-
from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
174-
175174
block_state = self.get_block_state(state)
176175
images = block_state.image
177176

src/diffusers/modular_pipelines/flux/inputs.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@
1515

1616
import torch
1717

18-
from ...pipelines import FluxPipeline
1918
from ...utils import logging
2019
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
2120
from ..modular_pipeline_utils import InputParam, OutputParam
22-
23-
# TODO: consider making these common utilities for modular if they are not pipeline-specific.
24-
from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size
2521
from .modular_pipeline import FluxModularPipeline
22+
from .pipeline_helpers import calculate_dimension_from_latents, pack_latents, repeat_tensor_to_batch_size
2623

2724

2825
logger = logging.get_logger(__name__)
@@ -209,7 +206,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
209206
# 2. Patchify the image latent tensor
210207
# TODO: Implement patchifier for Flux.
211208
latent_height, latent_width = image_latent_tensor.shape[2:]
212-
image_latent_tensor = FluxPipeline._pack_latents(
209+
image_latent_tensor = pack_latents(
213210
image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
214211
)
215212

@@ -266,7 +263,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
266263
# 2. Patchify the image latent tensor
267264
# TODO: Implement patchifier for Flux.
268265
latent_height, latent_width = image_latent_tensor.shape[2:]
269-
image_latent_tensor = FluxPipeline._pack_latents(
266+
image_latent_tensor = pack_latents(
270267
image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
271268
)
272269

src/diffusers/modular_pipelines/flux/pipeline_helpers.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
17+
18+
PREFERRED_KONTEXT_RESOLUTIONS = [
19+
(672, 1568),
20+
(688, 1504),
21+
(720, 1456),
22+
(752, 1392),
23+
(800, 1328),
24+
(832, 1248),
25+
(880, 1184),
26+
(944, 1104),
27+
(1024, 1024),
28+
(1104, 944),
29+
(1184, 880),
30+
(1248, 832),
31+
(1328, 800),
32+
(1392, 752),
33+
(1456, 720),
34+
(1504, 688),
35+
(1568, 672),
36+
]
37+
38+
39+
# Adapted from diffusers.pipelines.flux.pipeline_flux (FluxPipeline._prepare_latent_image_ids)
def prepare_latent_image_ids(batch_size, height, width, device, dtype):
    """Build the flattened positional-id table for a `height` x `width` grid.

    Args:
        batch_size: Unused. Kept so the signature stays compatible with existing
            call sites, which pass `None`.
        height: Number of rows in the grid.
        width: Number of columns in the grid.
        device: Device the returned tensor is placed on.
        dtype: Dtype the returned tensor is cast to.

    Returns:
        A tensor of shape `(height * width, 3)` in row-major order. Channel 0 is
        all zeros, channel 1 holds the row index and channel 2 the column index.
    """
    # Allocate on the target device from the start so no host-to-device copy is needed.
    # The grid is still built in the default float dtype and cast once at the end, which
    # keeps the values identical to building it on the CPU first.
    latent_image_ids = torch.zeros(height, width, 3, device=device)
    latent_image_ids[..., 1] = torch.arange(height, device=device)[:, None]
    latent_image_ids[..., 2] = torch.arange(width, device=device)[None, :]

    return latent_image_ids.reshape(height * width, 3).to(device=device, dtype=dtype)
52+
53+
54+
# Adapted from diffusers.pipelines.flux.pipeline_flux (FluxPipeline._pack_latents)
def pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Fold every 2x2 spatial patch of `latents` into the channel axis.

    Args:
        latents: Tensor that can be viewed as
            `(batch_size, num_channels_latents, height, width)`.
        batch_size: Size of the leading dimension of the result.
        num_channels_latents: Number of channels before packing.
        height: Spatial height of `latents`; must be even.
        width: Spatial width of `latents`; must be even.

    Returns:
        A tensor of shape
        `(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)`.
        Patches are ordered row-major; within a patch the 4 values of each channel
        are laid out row-major as well.

    Raises:
        ValueError: If `height` or `width` is not divisible by 2.
    """
    # Without this check an odd size only surfaces as an opaque size-mismatch error
    # raised by `view` below.
    if height % 2 != 0 or width % 2 != 0:
        raise ValueError(
            f"`height` and `width` must both be divisible by 2 to pack latents, but got {height} and {width}"
        )

    patch_rows, patch_cols = height // 2, width // 2

    latents = latents.view(batch_size, num_channels_latents, patch_rows, 2, patch_cols, 2)
    # (batch, channel, row, i, col, j) -> (batch, row, col, channel, i, j)
    latents = latents.permute(0, 2, 4, 1, 3, 5)

    return latents.reshape(batch_size, patch_rows * patch_cols, num_channels_latents * 4)
61+
62+
63+
# Adapted from diffusers.pipelines.flux.pipeline_flux
def unpack_latents(latents, height, width, vae_scale_factor):
    """Restore packed latents to a `(batch, channels // 4, latent_height, latent_width)` grid.

    Args:
        latents: Packed tensor of shape `(batch_size, num_patches, channels)`.
        height: Image height; converted to a latent height below.
        width: Image width; converted to a latent width below.
        vae_scale_factor: Factor used to convert image size to latent size.

    Returns:
        A tensor of shape `(batch_size, channels // 4, latent_height, latent_width)`,
        where each latent size is `2 * (int(size) // (vae_scale_factor * 2))`.

    Raises:
        ValueError: If `channels` is not divisible by 4, or if `num_patches` does not
            equal `(latent_height // 2) * (latent_width // 2)`.
    """
    batch_size, num_patches, channels = latents.shape

    # The VAE compresses images spatially, and packing additionally requires the
    # latent height and width to be divisible by 2.
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))
    patch_rows, patch_cols = height // 2, width // 2

    # Validate up front: with a channel count that is not a multiple of 4 the element
    # counts can still line up by accident, and `view` would then silently scramble data.
    if channels % 4 != 0:
        raise ValueError(f"Packed latents must have a channel count divisible by 4, but got {channels}")
    if num_patches != patch_rows * patch_cols:
        raise ValueError(
            f"Packed latents have {num_patches} patches, but a {height}x{width} latent grid "
            f"requires {patch_rows * patch_cols}"
        )

    latents = latents.view(batch_size, patch_rows, patch_cols, channels // 4, 2, 2)
    # (batch, row, col, channel, i, j) -> (batch, channel, row, i, col, j)
    latents = latents.permute(0, 3, 1, 4, 2, 5)

    return latents.reshape(batch_size, channels // 4, height, width)
78+
79+
80+
# Adapted from diffusers.modular_pipelines.qwenimage.inputs.calculate_dimension_from_latents
def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> tuple[int, int]:
    """Scale the last two dimensions of `latents` by `vae_scale_factor`.

    Args:
        latents: Unpacked latent tensor with 4 or 5 dimensions.
        vae_scale_factor: Multiplier applied to each of the last two dimensions.

    Returns:
        `(height, width)`, the last two sizes of `latents` each multiplied by
        `vae_scale_factor`.

    Raises:
        ValueError: If `latents` has neither 4 nor 5 dimensions.
    """
    if latents.ndim not in (4, 5):
        raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")

    *_, latent_height, latent_width = latents.shape
    return latent_height * vae_scale_factor, latent_width * vae_scale_factor
91+
92+
93+
# Adapted from diffusers.modular_pipelines.qwenimage.inputs.repeat_tensor_to_batch_size
def repeat_tensor_to_batch_size(
    input_name: str,
    input_tensor: torch.Tensor,
    batch_size: int,
    num_images_per_prompt: int = 1,
) -> torch.Tensor:
    """Repeat `input_tensor` along dim 0 so it covers `batch_size * num_images_per_prompt` rows.

    Args:
        input_name: Name used in error messages.
        input_tensor: Tensor whose leading dimension is either 1 or `batch_size`.
        batch_size: Expected batch size.
        num_images_per_prompt: Extra repeat factor applied to every row.

    Returns:
        `input_tensor` with each row repeated consecutively: `batch_size *
        num_images_per_prompt` times when the leading dimension is 1, otherwise
        `num_images_per_prompt` times.

    Raises:
        ValueError: If `input_tensor` is not a tensor, or its leading dimension is
            neither 1 nor `batch_size`.
    """
    if not isinstance(input_tensor, torch.Tensor):
        raise ValueError(f"`{input_name}` must be a tensor")

    current_batch = input_tensor.shape[0]
    if current_batch != 1 and current_batch != batch_size:
        raise ValueError(f"`{input_name}` must have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}")

    # A single row is broadcast to the whole batch; a full batch only needs the per-prompt factor.
    copies = batch_size * num_images_per_prompt if current_batch == 1 else num_images_per_prompt
    return input_tensor.repeat_interleave(copies, dim=0)

0 commit comments

Comments
 (0)