
chroma model/pipeline review #13619

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Reviewed: public/lazy imports, model config/loading hooks, runtime dtype/device paths, attention masks/processors, offload paths, docs, examples, fast/slow tests, and related Flux/Qwen precedents. Top-level imports for ChromaPipeline, ChromaImg2ImgPipeline, ChromaInpaintPipeline, and ChromaTransformer2DModel work.

Issue 1: Existing Chroma float-mask bug is still present

Affected code (the same pattern appears in each Chroma pipeline's _get_t5_prompt_embeds):

# for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer
seq_lengths = tokenizer_mask_device.sum(dim=1)
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)

seq_lengths = tokenizer_mask_device.sum(dim=1)
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)

seq_lengths = tokenizer_mask.sum(dim=1)
mask_indices = torch.arange(tokenizer_mask.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)

Problem:
The Chroma pipelines cast the padding mask to the prompt_embeds dtype. SDPA interprets a floating-point mask as an additive bias rather than a keep/drop mask, so a value of 0.0 leaves padding positions fully attended. This is an exact duplicate of the closed issue #12116 and the related earlier issue #11724, but it is still reproducible at this commit.
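
A standalone illustration of the SDPA semantics, independent of Chroma (shapes are arbitrary):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = k = v = torch.randn(1, 1, 4, 8)

# Bool mask: False positions are excluded from attention.
bool_mask = torch.tensor([True, True, False, True]).view(1, 1, 1, 4)
# The buggy cast produces 1.0/0.0 floats. SDPA *adds* float masks to the
# attention logits, so the 0.0 at the "masked" position is a no-op.
float_mask = bool_mask.to(torch.float32)

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
print((out_bool - out_float).abs().max().item())  # nonzero: the float mask masked nothing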

Impact:
Masked T5 padding can still influence image tokens, especially for short prompts, causing quality/parity regressions.

Reproduction:

import torch
from diffusers import ChromaTransformer2DModel

torch.manual_seed(0)
model = ChromaTransformer2DModel(
    in_channels=4, out_channels=4, num_layers=1, num_single_layers=1,
    attention_head_dim=4, num_attention_heads=1, joint_attention_dim=8,
    axes_dims_rope=(0, 2, 2), approximator_num_channels=16,
    approximator_hidden_dim=8, approximator_layers=1,
).eval()

base_encoder = torch.randn(1, 3, 8)
changed_encoder = base_encoder.clone()
changed_encoder[:, 2] += 1000  # token 2 is masked

common = dict(
    hidden_states=torch.randn(1, 2, 4),
    timestep=torch.ones(1),
    txt_ids=torch.zeros(3, 3),
    img_ids=torch.zeros(2, 3),
)

for dtype in (torch.bool, torch.float32):
    mask = torch.tensor([[1, 1, 0, 1, 1]], dtype=dtype)
    with torch.no_grad():
        a = model(encoder_hidden_states=base_encoder, attention_mask=mask, **common).sample
        b = model(encoder_hidden_states=changed_encoder, attention_mask=mask, **common).sample
    print(dtype, (a - b).abs().max().item())  # expect ~0 for bool, nonzero for float32

Relevant precedent:
attention_dispatch._normalize_attn_mask requires bool masks for mask-derived sequence lengths:

def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
    """
    Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
    FlashAttention/Sage varlen.

    Supports 1D to 4D shapes and common broadcasting patterns.
    """
    if attn_mask.dtype != torch.bool:
        raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")

Suggested fix:

# In all Chroma _get_t5_prompt_embeds methods:
attention_mask = mask_indices <= seq_lengths.unsqueeze(1)

# In all Chroma _prepare_attention_mask methods:
attention_mask = attention_mask.to(dtype=torch.bool)
image_attention_mask = torch.ones(
    batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool
)
attention_mask = torch.cat([attention_mask, image_attention_mask], dim=1)

Issue 2: ChromaInpaintPipeline crashes on missing guidance_embeds

Affected code:

# handle guidance
if self.transformer.config.guidance_embeds:
    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
    guidance = guidance.expand(latents.shape[0])
else:
    guidance = None

@register_to_config
def __init__(
    self,
    patch_size: int = 1,
    in_channels: int = 64,
    out_channels: int | None = None,
    num_layers: int = 19,
    num_single_layers: int = 38,
    attention_head_dim: int = 128,
    num_attention_heads: int = 24,
    joint_attention_dim: int = 4096,
    axes_dims_rope: tuple[int, ...] = (16, 56, 56),
    approximator_num_channels: int = 64,
    approximator_hidden_dim: int = 5120,
    approximator_layers: int = 5,
):

Problem:
The inpaint pipeline copied Flux's guidance-embedding handling, but ChromaTransformer2DModel registers no guidance_embeds config entry and its forward accepts no guidance argument. Accessing self.transformer.config.guidance_embeds therefore raises AttributeError.

Impact:
ChromaInpaintPipeline.__call__ fails before denoising with the standard Chroma transformer.

Reproduction:

from diffusers import ChromaTransformer2DModel

transformer = ChromaTransformer2DModel(
    in_channels=4, out_channels=4, num_layers=0, num_single_layers=0,
    attention_head_dim=4, num_attention_heads=1, joint_attention_dim=8,
    axes_dims_rope=(0, 2, 2), approximator_num_channels=16,
    approximator_hidden_dim=8,
)

print("guidance_embeds" in transformer.config)
print(transformer.config.guidance_embeds)  # AttributeError

Relevant precedent:
Flux only does this because FluxTransformer2DModel registers guidance_embeds:

@register_to_config
def __init__(
    self,
    patch_size: int = 1,
    in_channels: int = 64,
    out_channels: int | None = None,
    num_layers: int = 19,
    num_single_layers: int = 38,
    attention_head_dim: int = 128,
    num_attention_heads: int = 24,
    joint_attention_dim: int = 4096,
    pooled_projection_dim: int = 768,
    guidance_embeds: bool = False,
    axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
):
    super().__init__()
    self.out_channels = out_channels or in_channels
    self.inner_dim = num_attention_heads * attention_head_dim
    self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
    text_time_guidance_cls = (
        CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
    )
    self.time_text_embed = text_time_guidance_cls(

Suggested fix:

# Remove the inpaint-only guidance_embeds block entirely.
# ChromaTransformer2DModel has no guidance input and the local `guidance` value is unused.
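
If a guidance branch is worth keeping for hypothetical future Chroma variants, a defensive alternative is to guard the config access (a sketch only; the current transformer registers no such flag, so this always takes the else branch):

if getattr(self.transformer.config, "guidance_embeds", False):
    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
    guidance = guidance.expand(latents.shape[0])
else:
    guidance = None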

Issue 3: Gradient checkpointing drops Chroma attention masks in single blocks

Affected code:

if torch.is_grad_enabled() and self.gradient_checkpointing:
    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
        block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
    )
else:
    encoder_hidden_states, hidden_states = block(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        temb=temb,
        image_rotary_emb=image_rotary_emb,
        attention_mask=attention_mask,
        joint_attention_kwargs=joint_attention_kwargs,
    )

for index_block, block in enumerate(self.single_transformer_blocks):
    start_idx = 3 * index_block
    temb = pooled_temb[:, start_idx : start_idx + 3]
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        hidden_states = self._gradient_checkpointing_func(
            block,
            hidden_states,
            temb,
            image_rotary_emb,
        )
    else:
        hidden_states = block(
            hidden_states=hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
            joint_attention_kwargs=joint_attention_kwargs,
        )

Problem:
The gradient-checkpointed dual-block path omits joint_attention_kwargs, and the single-block path omits both attention_mask and joint_attention_kwargs. With checkpointing enabled, the model no longer computes the same function.
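
The failure mode is generic to activation checkpointing, which _gradient_checkpointing_func wraps by default: the checkpointed call sees only the arguments actually forwarded. A minimal sketch with plain torch.utils.checkpoint:

import torch
from torch.utils.checkpoint import checkpoint

def block(x, mask=None):
    # Stand-in for a transformer block that respects an optional mask.
    return x if mask is None else x.masked_fill(~mask, 0.0)

x = torch.randn(4, requires_grad=True)
mask = torch.tensor([True, False, True, True])

full = block(x, mask)
ckpt = checkpoint(block, x, use_reentrant=False)  # mask omitted, as in the single-block path
print(torch.allclose(full, ckpt))  # False: checkpointing silently computes a different function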

Impact:
Training/fine-tuning with gradient checkpointing can silently ignore padding masks in single-stream blocks and diverge from non-checkpointed training.

Reproduction:

import torch
from diffusers import ChromaTransformer2DModel

torch.manual_seed(0)
model = ChromaTransformer2DModel(
    in_channels=4, out_channels=4, num_layers=0, num_single_layers=1,
    attention_head_dim=4, num_attention_heads=1, joint_attention_dim=8,
    axes_dims_rope=(0, 2, 2), approximator_num_channels=16,
    approximator_hidden_dim=8, approximator_layers=1,
).eval()

inputs = dict(
    hidden_states=torch.randn(1, 2, 4, requires_grad=True),
    encoder_hidden_states=torch.randn(1, 2, 8, requires_grad=True),
    timestep=torch.ones(1),
    txt_ids=torch.zeros(2, 3),
    img_ids=torch.zeros(2, 3),
    attention_mask=torch.tensor([[1, 0, 1, 1]], dtype=torch.bool),
)

out_no_ckpt = model(**inputs).sample.detach()
model.enable_gradient_checkpointing()
out_ckpt = model(**inputs).sample.detach()
print((out_no_ckpt - out_ckpt).abs().max().item())  # nonzero: mask dropped under checkpointing

Relevant precedent:
No exact duplicate found for ChromaTransformer2DModel attention_mask gradient_checkpointing.

Suggested fix:

# Dual-stream blocks: also forward joint_attention_kwargs.
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, joint_attention_kwargs
)

# Single-stream blocks: forward attention_mask and joint_attention_kwargs.
hidden_states = self._gradient_checkpointing_func(
    block, hidden_states, temb, image_rotary_emb, attention_mask, joint_attention_kwargs
)

Issue 4: Chroma inpaint has no fast tests, and Chroma has no slow tests

Affected code:

class ChromaPipelineFastTests(
    unittest.TestCase,
    PipelineTesterMixin,
    FluxIPAdapterTesterMixin,
):
    pipeline_class = ChromaPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])


class ChromaImg2ImgPipelineFastTests(
    unittest.TestCase,
    PipelineTesterMixin,
    FluxIPAdapterTesterMixin,
):
    pipeline_class = ChromaImg2ImgPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])

Problem:
Only text2img and img2img fast pipeline tests exist. There is no ChromaInpaintPipeline fast test, and no Chroma slow tests anywhere under tests/.

Impact:
The inpaint runtime crash above is not covered, and there is no slow coverage against published Chroma checkpoints.

Reproduction:

from pathlib import Path

files = sorted(str(p).replace("\\", "/") for p in Path("tests").rglob("*chroma*.py"))
print("\n".join(files))
print("has_inpaint_fast_test", any("inpaint" in f for f in files))
print("has_slow_marker", any("@slow" in Path(f).read_text(encoding="utf-8") for f in files))

Relevant precedent:
Flux and QwenImage have inpaint fast-test classes:

class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
    pipeline_class = FluxInpaintPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
    batch_params = frozenset(["prompt"])
    test_xformers_attention = False

    def get_dummy_components(self):


class QwenImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = QwenImageInpaintPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

Suggested fix:

# Add tests/pipelines/chroma/test_pipeline_chroma_inpainting.py using the existing
# Chroma tiny components plus image/mask tensors, and add at least one @slow Chroma
# pipeline test against a published Chroma checkpoint or saved test slices.
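
A possible skeleton, modeled on the Flux precedent above. The class name, params set, and method stubs are assumptions; the stub bodies should reuse the tiny components from the existing Chroma fast tests:

import unittest

from diffusers import ChromaInpaintPipeline

from ..test_pipelines_common import PipelineTesterMixin


class ChromaInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    pipeline_class = ChromaInpaintPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
    batch_params = frozenset(["prompt"])
    test_xformers_attention = False

    def get_dummy_components(self):
        # Reuse the tiny transformer/T5/VAE components from ChromaPipelineFastTests.
        ...

    def get_dummy_inputs(self, device, seed=0):
        # As in the text2img fast test, plus `image` and `mask_image` tensors.
        ...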

Issue 5: ChromaTransformer2DModel uses deprecated FluxPosEmbed import path

Affected code:

from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
from .transformer_flux import FluxAttention, FluxAttnProcessor

Problem:
FluxPosEmbed is imported from diffusers.models.embeddings, whose shim emits a deprecation warning and asks callers to import from diffusers.models.transformers.transformer_flux.

Impact:
Every Chroma transformer construction emits a user-visible FutureWarning.

Reproduction:

import warnings
from diffusers import ChromaTransformer2DModel

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ChromaTransformer2DModel(
        in_channels=4, out_channels=4, num_layers=0, num_single_layers=0,
        attention_head_dim=4, num_attention_heads=1, joint_attention_dim=8,
        axes_dims_rope=(0, 2, 2), approximator_num_channels=16,
        approximator_hidden_dim=8,
    )

print([str(w.message) for w in caught if "FluxPosEmbed" in str(w.message)])

Relevant precedent:
The non-deprecated class lives here:

class FluxPosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: list[int]):
        super().__init__()
        self.theta = theta

Suggested fix:

from ..embeddings import PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
from .transformer_flux import FluxAttention, FluxAttnProcessor, FluxPosEmbed
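
A quick sanity check that the canonical path is warning-free, mirroring the reproduction above (import path taken from the shim's own deprecation message):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from diffusers.models.transformers.transformer_flux import FluxPosEmbed  # noqa: F401

print([str(w.message) for w in caught if "FluxPosEmbed" in str(w.message)])  # []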

Issue 6: Chroma inpaint/output docs contain copied or AI-artifact text

Affected code:

"""
ChromaInpaintPipeline implements a text-guided image inpainting pipeline for the lodestones/Chroma1-HD model, based on
the ChromaPipeline from Hugging Face Diffusers:contentReference[oaicite:0]{index=0} and the Stable Diffusion inpainting
approach:contentReference[oaicite:1]{index=1}.

r"""
The Flux pipeline for image inpainting.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Args:
transformer ([`ChromaTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`DDIMScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).

"""
Output class for Stable Diffusion pipelines.
Args:
images (`list[PIL.Image.Image]` or `np.ndarray`)
list of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.

Problem:
The inpaint file contains contentReference[oaicite:*] artifacts, and its class docstring describes Flux, Black Forest Labs, DDIM, CLIP, text_encoder_2, and tokenizer_2, none of which match this Chroma pipeline's signature. ChromaPipelineOutput likewise describes itself as an output class for "Stable Diffusion pipelines."

Impact:
Generated API docs are misleading and violate the repo review rule against ephemeral context/artifacts.

Reproduction:

from pathlib import Path

inpaint = Path("src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py").read_text()
output = Path("src/diffusers/pipelines/chroma/pipeline_output.py").read_text()

for token in ["contentReference[oaicite", "The Flux pipeline", "[`DDIMScheduler`]", "[`CLIPTextModel`]", "text_encoder_2"]:
    print(token, token in inpaint)
print("Stable Diffusion output doc", "Output class for Stable Diffusion pipelines." in output)

Relevant precedent:
No duplicate issue or PR found for the doc artifacts.

Suggested fix:

# Replace the copied inpaint docstring with Chroma-specific text:
# - Chroma image inpainting
# - FlowMatchEulerDiscreteScheduler
# - T5EncoderModel / T5TokenizerFast
# - no text_encoder_2/tokenizer_2/CLIP text encoder
# Also change ChromaPipelineOutput to "Output class for Chroma pipelines."
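
A possible replacement, restricted to the components named above (wording is a sketch, not final copy):

r"""
The Chroma pipeline for image inpainting.

Args:
    transformer ([`ChromaTransformer2DModel`]):
        Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
    scheduler ([`FlowMatchEulerDiscreteScheduler`]):
        A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    vae ([`AutoencoderKL`]):
        Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
    text_encoder ([`T5EncoderModel`]):
        [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
        the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
    tokenizer (`T5TokenizerFast`):
        Tokenizer of class
        [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""

# And in pipeline_output.py:
"""
Output class for Chroma pipelines.
"""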
