115 changes: 94 additions & 21 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -21,6 +21,7 @@
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...image_processor import IPAdapterMaskProcessor
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
@@ -244,28 +245,100 @@ def __call__(
# IP-adapter
ip_attn_output = torch.zeros_like(hidden_states)

for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, list):
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
raise ValueError(
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
f"({len(ip_hidden_states)})"
)
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
if mask is None:
continue
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape "
"[1, num_images_for_ip_adapter, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
num_ip_images = 1 if ip_state.ndim == 3 else ip_state.shape[1]
if mask.shape[1] != num_ip_images:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of ip images ({num_ip_images}) at index {index}"
)
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of scales ({len(scale)}) at index {index}"
)
else:
ip_adapter_masks = [None] * len(self.scale)

for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)

current_ip_hidden_states = dispatch_attention_fn(
ip_query,
ip_key,
ip_value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
ip_attn_output += scale * current_ip_hidden_states
if mask is not None:
if current_ip_hidden_states.ndim == 3:
current_ip_hidden_states = current_ip_hidden_states[:, None, :, :]
if not isinstance(scale, list):
scale = [scale] * mask.shape[1]

for i in range(mask.shape[1]):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)

_current_ip_hidden_states = dispatch_attention_fn(
ip_query,
ip_key,
ip_value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
_current_ip_hidden_states = _current_ip_hidden_states.reshape(
batch_size, -1, attn.heads * attn.head_dim
)
_current_ip_hidden_states = _current_ip_hidden_states.to(ip_query.dtype)

mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)
mask_downsample = mask_downsample.to(dtype=ip_query.dtype, device=ip_query.device)

ip_attn_output += scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)

current_ip_hidden_states = dispatch_attention_fn(
ip_query,
ip_key,
ip_value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
current_ip_hidden_states = current_ip_hidden_states.reshape(
batch_size, -1, attn.heads * attn.head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
ip_attn_output += scale * current_ip_hidden_states

return hidden_states, encoder_hidden_states, ip_attn_output
else:
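
Side note on driving the new mask path (a minimal sketch, not part of this diff): masks are preprocessed with `IPAdapterMaskProcessor` and arrive here as a list with one entry of shape `[1, num_images_for_ip_adapter, height, width]` per adapter. The `mask_image` construction and the toy sizes below are illustrative assumptions.

import torch
from PIL import Image
from diffusers.image_processor import IPAdapterMaskProcessor

mask_image = Image.new("L", (1024, 1024), 255)  # all-white grayscale mask, illustrative only
processor = IPAdapterMaskProcessor()
mask = processor.preprocess([mask_image], height=1024, width=1024)  # [1, 1, 1024, 1024]

# What the loop above does per IP image: project the spatial mask onto the
# packed token sequence so it can gate the attention output elementwise.
mask_downsample = IPAdapterMaskProcessor.downsample(
    mask[:, 0, :, :],     # mask for IP image i
    batch_size=2,
    num_queries=4096,     # (1024 // 16) ** 2 packed 2x2 latent patches
    value_embed_dim=128,  # toy value; attn.heads * attn.head_dim in practice
)
print(mask_downsample.shape)  # torch.Size([2, 4096, 128])
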
14 changes: 6 additions & 8 deletions src/diffusers/modular_pipelines/flux/before_denoise.py
@@ -17,13 +17,13 @@
import numpy as np
import torch

from ...pipelines import FluxPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
from .pipeline_helpers import pack_latents, prepare_latent_image_ids


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -390,7 +390,7 @@ def prepare_latents(

# TODO: move packing latents code to a patchifier similar to Qwen
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)
latents = pack_latents(latents, batch_size, num_channels_latents, height, width)

return latents

@@ -470,7 +470,7 @@ def intermediate_outputs(self) -> list[OutputParam]:
def check_inputs(image_latents, latents):
if image_latents.shape[0] != latents.shape[0]:
raise ValueError(
f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
f"`image_latents` must have the same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
)

if image_latents.ndim != 3:
@@ -541,7 +541,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip

height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
block_state.img_ids = prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)

self.set_block_state(state, block_state)

@@ -598,15 +598,13 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
):
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
img_ids = FluxPipeline._prepare_latent_image_ids(
None, image_latent_height // 2, image_latent_width // 2, device, dtype
)
img_ids = prepare_latent_image_ids(None, image_latent_height // 2, image_latent_width // 2, device, dtype)
# image ids are the same as latent ids with the first dimension set to 1 instead of 0
img_ids[..., 0] = 1

height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
latent_ids = prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)

if img_ids is not None:
latent_ids = torch.cat([latent_ids, img_ids], dim=0)
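
A quick worked example of the dimension math used in these blocks (illustrative, assuming the Flux default `vae_scale_factor` of 8): the rounding snaps the pixel size down to a multiple of 16 (8x VAE compression times the 2x2 packing), and the `// 2` passed to `prepare_latent_image_ids` counts packed patches rather than latent pixels.

vae_scale_factor = 8   # Flux VAE default
height = width = 1000  # a request that is not divisible by 16

latent_height = 2 * (int(height) // (vae_scale_factor * 2))  # 2 * 62 = 124 latent rows
latent_width = 2 * (int(width) // (vae_scale_factor * 2))    # 124 latent cols
num_tokens = (latent_height // 2) * (latent_width // 2)      # 62 * 62 = 3844 image tokens
print(latent_height, latent_width, num_tokens)               # 124 124 3844
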
19 changes: 2 additions & 17 deletions src/diffusers/modular_pipelines/flux/decoders.py
@@ -24,27 +24,12 @@
from ...video_processor import VaeImageProcessor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .pipeline_helpers import unpack_latents


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape

# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents


class FluxDecodeStep(ModularPipelineBlocks):
model_name = "flux"

@@ -95,7 +80,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState:

if not block_state.output_type == "latent":
latents = block_state.latents
latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor)
latents = unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
3 changes: 1 addition & 2 deletions src/diffusers/modular_pipelines/flux/encoders.py
@@ -26,6 +26,7 @@
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
from .pipeline_helpers import PREFERRED_KONTEXT_RESOLUTIONS


if is_ftfy_available():
@@ -170,8 +171,6 @@ def intermediate_outputs(self) -> list[OutputParam]:

@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState):
from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS

block_state = self.get_block_state(state)
images = block_state.image

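
With `PREFERRED_KONTEXT_RESOLUTIONS` now imported at module level from `pipeline_helpers`, the table is consumed the way the Kontext pipeline consumes it: pick the (width, height) bucket whose aspect ratio is closest to the reference image's. A sketch (the helper name `nearest_kontext_resolution` is made up for illustration):

from diffusers.modular_pipelines.flux.pipeline_helpers import PREFERRED_KONTEXT_RESOLUTIONS

def nearest_kontext_resolution(width: int, height: int) -> tuple[int, int]:
    # Each entry is (width, height); rank buckets by aspect-ratio distance.
    aspect_ratio = width / height
    _, best_w, best_h = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
    return best_w, best_h

print(nearest_kontext_resolution(1920, 1080))  # (1392, 752)
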
9 changes: 3 additions & 6 deletions src/diffusers/modular_pipelines/flux/inputs.py
@@ -15,14 +15,11 @@

import torch

from ...pipelines import FluxPipeline
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import InputParam, OutputParam

# TODO: consider making these common utilities for modular if they are not pipeline-specific.
from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size
from .modular_pipeline import FluxModularPipeline
from .pipeline_helpers import calculate_dimension_from_latents, pack_latents, repeat_tensor_to_batch_size


logger = logging.get_logger(__name__)
@@ -209,7 +206,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
# 2. Patchify the image latent tensor
# TODO: Implement patchifier for Flux.
latent_height, latent_width = image_latent_tensor.shape[2:]
image_latent_tensor = FluxPipeline._pack_latents(
image_latent_tensor = pack_latents(
image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
)

@@ -266,7 +263,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
# 2. Patchify the image latent tensor
# TODO: Implement patchifier for Flux.
latent_height, latent_width = image_latent_tensor.shape[2:]
image_latent_tensor = FluxPipeline._pack_latents(
image_latent_tensor = pack_latents(
image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
)

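
For reference, the shape effect of the `pack_latents` calls above (a sketch with toy sizes): a `[batch, channels, H, W]` latent tensor becomes `[batch, (H // 2) * (W // 2), channels * 4]` packed tokens, one token per 2x2 latent patch.

import torch
from diffusers.modular_pipelines.flux.pipeline_helpers import pack_latents

latents = torch.randn(2, 16, 128, 128)  # [batch, channels, latent_h, latent_w]
packed = pack_latents(latents, batch_size=2, num_channels_latents=16, height=128, width=128)
print(packed.shape)  # torch.Size([2, 4096, 64])
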
110 changes: 110 additions & 0 deletions src/diffusers/modular_pipelines/flux/pipeline_helpers.py
@@ -0,0 +1,110 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


PREFERRED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]


# Copied from diffusers.pipelines.flux.pipeline_flux
def prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)

return latent_image_ids.to(device=device, dtype=dtype)


# Copied from diffusers.pipelines.flux.pipeline_flux
def pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

return latents


# Copied from diffusers.pipelines.flux.pipeline_flux
def unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape

# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))

latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents


# Copied from diffusers.modular_pipelines.qwenimage.inputs.calculate_dimension_from_latents
def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> tuple[int, int]:
if latents.ndim != 4 and latents.ndim != 5:
raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")

latent_height, latent_width = latents.shape[-2:]

height = latent_height * vae_scale_factor
width = latent_width * vae_scale_factor

return height, width


# Copied from diffusers.modular_pipelines.qwenimage.inputs.repeat_tensor_to_batch_size
def repeat_tensor_to_batch_size(
input_name: str,
input_tensor: torch.Tensor,
batch_size: int,
num_images_per_prompt: int = 1,
) -> torch.Tensor:
if not isinstance(input_tensor, torch.Tensor):
raise ValueError(f"`{input_name}` must be a tensor")

if input_tensor.shape[0] == 1:
repeat_by = batch_size * num_images_per_prompt
elif input_tensor.shape[0] == batch_size:
repeat_by = num_images_per_prompt
else:
raise ValueError(f"`{input_name}` must have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}")

return input_tensor.repeat_interleave(repeat_by, dim=0)
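
A small self-check for the copied helpers (an illustrative sketch under toy shapes, not part of the PR):

import torch

from diffusers.modular_pipelines.flux.pipeline_helpers import (
    pack_latents,
    prepare_latent_image_ids,
    repeat_tensor_to_batch_size,
    unpack_latents,
)

# Image ids enumerate packed patches as [0, row, col] triples.
ids = prepare_latent_image_ids(None, 2, 3, "cpu", torch.float32)
assert ids.shape == (6, 3) and ids[4].tolist() == [0.0, 1.0, 1.0]

# pack/unpack round-trip (height/width in pixels, vae_scale_factor=8).
latents = torch.randn(1, 16, 64, 64)
packed = pack_latents(latents, 1, 16, 64, 64)
restored = unpack_latents(packed, height=64 * 8, width=64 * 8, vae_scale_factor=8)
assert torch.equal(restored, latents)

# A singleton tensor is repeated to batch_size * num_images_per_prompt.
embeds = torch.randn(1, 77, 768)
out = repeat_tensor_to_batch_size("prompt_embeds", embeds, batch_size=2, num_images_per_prompt=2)
assert out.shape[0] == 4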