lumina model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Reviewed target files, top-level/lazy exports, dummy exports, docs, fast/slow tests, dtype/device handling, offload-related config, and attention processor behavior. Fast model and pipeline tests exist, and a slow Lumina pipeline test exists. Targeted pytest collection was attempted with .venv, but the local Torch build is missing torch._C._distributed_c10d, so shared test mixins fail during collection before Lumina tests run.
Duplicate search status: focused gh CLI and GitHub connector searches found no exact duplicates for the Lumina findings. Related but not exact duplicates: #10827, #13613, #11368.
Issue 1: Deprecated alias is exported but cannot be constructed
Affected code (diffusers/src/diffusers/pipelines/lumina/pipeline_lumina.py, lines 940-950 in 0f1abc4):
class LuminaText2ImgPipeline(LuminaPipeline):
    def __init__(
        self,
        transformer: LuminaNextDiT2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: GemmaPreTrainedModel,
        tokenizer: GemmaTokenizer | GemmaTokenizerFast,
    ):
        deprecation_message = "`LuminaText2ImgPipeline` has been renamed to `LuminaPipeline` and will be removed in a future version. Please use `LuminaPipeline` instead."
        deprecate("diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline", "0.34", deprecation_message)
Public export (diffusers/src/diffusers/__init__.py, lines 624-625 in 0f1abc4):
        "LuminaPipeline",
        "LuminaText2ImgPipeline",
Problem:
LuminaText2ImgPipeline is still publicly exported, but its constructor calls deprecate(..., "0.34", ...). The current version is 0.38.0.dev0, which is already past that target, so construction raises a ValueError instead of emitting a deprecation warning.
Impact:
Users can import the backwards-compatible alias, but any path that instantiates it fails immediately.
Reproduction:
from diffusers import LuminaText2ImgPipeline
LuminaText2ImgPipeline(None, None, None, None, None)  # raises ValueError because 0.38.0.dev0 >= 0.34
Relevant precedent:
Related rename PR, but it does not remove/fix the current exported alias failure: #10827
Suggested fix:
# Either remove the alias from pipeline_lumina.py, lazy exports, top-level exports,
# and dummy objects, or bump the deprecation target to a future version if keeping it.
deprecate(
"diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline",
"1.0.0",
deprecation_message,
)
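If the alias is kept, a minimal regression check could look like this (a sketch, assuming diffusers' deprecate emits a FutureWarning once the target version lies in the future again, and that the None-component construction path from the reproduction then completes):
import pytest
from diffusers import LuminaText2ImgPipeline

# With a future deprecation target, constructing the alias should warn rather than raise.
with pytest.warns(FutureWarning):
    LuminaText2ImgPipeline(None, None, None, None, None)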
Issue 2: VAE scale factor is hardcoded
Affected code (diffusers/src/diffusers/pipelines/lumina/pipeline_lumina.py, lines 196-197 in 0f1abc4):
        self.vae_scale_factor = 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
Problem:
vae_scale_factor is hardcoded to 8 instead of derived from vae.config.block_out_channels.
Impact:
Custom/tiny VAEs compute wrong latent sizes, default image sizes, and input divisibility checks.
Reproduction:
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline
transformer = LuminaNextDiT2DModel(sample_size=4, hidden_size=24, num_layers=1, num_attention_heads=3, num_kv_heads=1, multiple_of=16, learn_sigma=False, cross_attention_dim=32)
vae = AutoencoderKL(block_out_channels=(32, 64), down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"), up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"), latent_channels=4)
pipe = LuminaPipeline(transformer, FlowMatchEulerDiscreteScheduler(), vae, None, None)
print(pipe.vae_scale_factor) # 8
print(2 ** (len(vae.config.block_out_channels) - 1)) # 2
Relevant precedent (diffusers/src/diffusers/pipelines/flux/pipeline_flux.py, lines 198-212 in 0f1abc4):
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
Suggested fix:
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
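As a quick sanity check of the derived value, reusing the tiny two-block VAE from the reproduction above (after the fix, the pipeline would report 2 instead of the hardcoded 8):
from diffusers import AutoencoderKL

vae = AutoencoderKL(
    block_out_channels=(32, 64),
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    latent_channels=4,
)
# Same expression as in the suggested fix, evaluated standalone.
print(2 ** (len(vae.config.block_out_channels) - 1))  # 2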
Issue 3: max_sequence_length is ignored
Affected code (diffusers/src/diffusers/pipelines/lumina/pipeline_lumina.py, lines 206-222 in 0f1abc4):
    def _get_gemma_prompt_embeds(
        self,
        prompt: str | list[str],
        num_images_per_prompt: int = 1,
        device: torch.device | None = None,
        clean_caption: bool | None = False,
        max_length: int | None = None,
    ):
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
        text_inputs = self.tokenizer(
            prompt,
            pad_to_multiple_of=8,
            max_length=self.max_sequence_length,
And the call site in __call__ (same file, lines 782-794 in 0f1abc4):
        ) = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            clean_caption=clean_caption,
            max_sequence_length=max_sequence_length,
        )
Problem:
__call__ accepts max_sequence_length, but encode_prompt only captures it in **kwargs, and _get_gemma_prompt_embeds always uses self.max_sequence_length.
Impact:
Users cannot shorten or adjust prompt tokenization through the documented pipeline argument.
Reproduction:
import torch
from types import SimpleNamespace
from diffusers import LuminaPipeline
class Tok:
def __call__(self, prompt, **kw):
n = kw.get("max_length") or 11
return SimpleNamespace(input_ids=torch.zeros(len(prompt), n, dtype=torch.long), attention_mask=torch.ones(len(prompt), n))
def batch_decode(self, ids): return [""]
class Enc(torch.nn.Module):
dtype = torch.float32
def forward(self, input_ids, **kw):
h = torch.zeros(input_ids.shape[0], input_ids.shape[1], 8)
return SimpleNamespace(hidden_states=[h, h, h])
pipe = LuminaPipeline.__new__(LuminaPipeline)
pipe.max_sequence_length, pipe.tokenizer, pipe.text_encoder, pipe.transformer = 256, Tok(), Enc(), None
embeds, mask, *_ = pipe.encode_prompt("x", do_classifier_free_guidance=False, device=torch.device("cpu"), max_sequence_length=7)
print(embeds.shape, mask.shape) # torch.Size([1, 256, 8]) torch.Size([1, 256])
Relevant precedent (diffusers/src/diffusers/pipelines/lumina2/pipeline_lumina2.py, lines 192-215 in 0f1abc4):
    def _get_gemma_prompt_embeds(
        self,
        prompt: str | list[str],
        device: torch.device | None = None,
        max_sequence_length: int = 256,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because Gemma can only handle sequences up to"
                f" {max_sequence_length} tokens: {removed_text}"
Suggested fix:
def _get_gemma_prompt_embeds(..., max_sequence_length: int = 256):
...
max_length=max_sequence_length
...
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
# and forward max_sequence_length from encode_prompt into _get_gemma_prompt_embeds
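A slightly fuller sketch of the plumbing, following the Lumina2 precedent above. The signatures are illustrative: parameter lists are abbreviated to what this issue touches, and elided parts (marked with ...) are unchanged:
# Hypothetical shape of the fix, not a drop-in patch.
def _get_gemma_prompt_embeds(
    self,
    prompt,
    num_images_per_prompt: int = 1,
    device=None,
    clean_caption: bool = False,
    max_sequence_length: int = 256,  # replaces the unused max_length parameter
):
    ...
    text_inputs = self.tokenizer(
        prompt,
        pad_to_multiple_of=8,
        max_length=max_sequence_length,  # was: self.max_sequence_length
        # remaining tokenizer kwargs unchanged
    )
    ...
    # the truncation warning should also use the argument:
    # removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])


# encode_prompt accepts max_sequence_length explicitly instead of dropping it into **kwargs
# and forwards it to the helper (other parameters elided here).
def encode_prompt(self, prompt, do_classifier_free_guidance=True, max_sequence_length: int = 256, **kwargs):
    ...
    prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        clean_caption=clean_caption,
        max_sequence_length=max_sequence_length,  # forward the caller's value
    )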
Issue 4: Prompt conditioning expansion is wrong for multiple images
Affected code (diffusers/src/diffusers/pipelines/lumina/pipeline_lumina.py, lines 252-257 in 0f1abc4):
        _, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
        prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
And encode_prompt (same file, lines 310-369 in 0f1abc4):
        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                clean_caption=clean_caption,
            )

        # Get negative embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt if negative_prompt is not None else ""

            # Normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            # Padding negative prompt to the same length with prompt
            prompt_max_length = prompt_embeds.shape[1]
            negative_text_inputs = self.tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=prompt_max_length,
                truncation=True,
                return_tensors="pt",
            )
            negative_text_input_ids = negative_text_inputs.input_ids.to(device)
            negative_prompt_attention_mask = negative_text_inputs.attention_mask.to(device)
            # Get the negative prompt embeddings
            negative_prompt_embeds = self.text_encoder(
                negative_text_input_ids,
                attention_mask=negative_prompt_attention_mask,
                output_hidden_states=True,
            )

            negative_dtype = self.text_encoder.dtype
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
            _, seq_len, _ = negative_prompt_embeds.shape

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=negative_dtype, device=device)
            # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
            negative_prompt_attention_mask = negative_prompt_attention_mask.view(
                batch_size * num_images_per_prompt, -1
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
Problem:
Generated prompt masks use repeat(num_images_per_prompt, 1), which orders masks as [p1, p2, p1, p2] while embeddings are [p1, p1, p2, p2]. Also, when users pass precomputed prompt/negative embeddings, encode_prompt does not expand them for num_images_per_prompt.
Impact:
Batched prompts with different mask lengths can pair embeddings with the wrong masks, and precomputed-embedding workflows fail or mis-broadcast when generating multiple images per prompt.
Reproduction:
import torch
from diffusers import LuminaPipeline
pipe = LuminaPipeline.__new__(LuminaPipeline)
pe = torch.zeros(1, 5, 8)
pm = torch.ones(1, 5, dtype=torch.long)
ne = torch.ones(1, 5, 8)
nm = torch.ones(1, 5, dtype=torch.long)
out = pipe.encode_prompt(
prompt=None, do_classifier_free_guidance=True, num_images_per_prompt=2,
device=torch.device("cpu"), prompt_embeds=pe, prompt_attention_mask=pm,
negative_prompt_embeds=ne, negative_prompt_attention_mask=nm,
)
print(out[0].shape, out[2].shape) # both still batch 1; expected batch 2 each
Relevant precedent (diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py, lines 390-431 in 0f1abc4):
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
        prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            negative_prompt_attention_mask = uncond_input.attention_mask
            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
Suggested fix:
bs, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(bs * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt).view(bs * num_images_per_prompt, -1)
if do_classifier_free_guidance:
bs, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(bs * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt).view(bs * num_images_per_prompt, -1)
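To make the ordering mismatch concrete, a standalone check with a two-prompt batch and two images per prompt (pure torch, independent of the pipeline):
import torch

batch_size, num_images_per_prompt = 2, 2
mask = torch.tensor([[1, 1, 1], [1, 0, 0]])  # p1 mask, p2 mask

# Expansion matching the embeddings: repeat along dim 1, then view -> [p1, p1, p2, p2]
right = mask.repeat(1, num_images_per_prompt).view(batch_size * num_images_per_prompt, -1)

# Current mask expansion: repeat along dim 0, then view -> [p1, p2, p1, p2]
wrong = mask.repeat(num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, -1)

print(right.tolist())  # [[1, 1, 1], [1, 1, 1], [1, 0, 0], [1, 0, 0]]
print(wrong.tolist())  # [[1, 1, 1], [1, 0, 0], [1, 1, 1], [1, 0, 0]]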
Issue 5: Provided latents are not cast to the requested dtype
Affected code (diffusers/src/diffusers/pipelines/lumina/pipeline_lumina.py, lines 597-615 in 0f1abc4):
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        return latents
Problem:
prepare_latents moves user-provided latents to the device but leaves dtype unchanged.
Impact:
Supplying float32 latents to a bf16/fp16 pipeline can feed mismatched activations into lower-precision transformer weights.
Reproduction:
import torch
from diffusers import LuminaPipeline
pipe = LuminaPipeline.__new__(LuminaPipeline)
pipe.vae_scale_factor = 8
latents = torch.ones(1, 4, 2, 2, dtype=torch.float32)
out = pipe.prepare_latents(1, 4, 16, 16, torch.bfloat16, torch.device("cpu"), None, latents)
print(out.dtype) # torch.float32
Relevant precedent (diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py, lines 633-646 in 0f1abc4):
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        if latents is not None:
            return latents.to(device=device, dtype=dtype)
Suggested fix:
else:
latents = latents.to(device=device, dtype=dtype)
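To make the impact concrete, a standalone illustration of the dtype mismatch that un-cast latents can trigger in a lower-precision module (pure torch, not Lumina-specific):
import torch

proj = torch.nn.Linear(4, 4).to(torch.bfloat16)  # stands in for a bf16 transformer layer
latents = torch.ones(1, 4, dtype=torch.float32)  # user-provided fp32 latents left un-cast

try:
    proj(latents)
except RuntimeError as e:
    # e.g. "mat1 and mat2 must have the same dtype" (exact message depends on the torch build)
    print(type(e).__name__, e)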
Issue 6: Lumina attention bypasses the attention backend dispatcher
Affected code (diffusers/src/diffusers/models/transformers/lumina_nextdit2d.py, lines 71-96 in 0f1abc4):
        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="layer_norm_across_heads" if qk_norm else None,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=LuminaAttnProcessor2_0(),
        )
        self.attn1.to_out = nn.Identity()

        # Cross-attention
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            dim_head=dim // num_attention_heads,
            qk_norm="layer_norm_across_heads" if qk_norm else None,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=LuminaAttnProcessor2_0(),
And the processor (diffusers/src/diffusers/models/attention_processor.py, lines 3572-3665 in 0f1abc4):
class LuminaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        query_rotary_emb: torch.Tensor | None = None,
        key_rotary_emb: torch.Tensor | None = None,
        base_sequence_length: int | None = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        # Apply Query-Key Norm if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.view(batch_size, -1, attn.heads, head_dim)

        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply RoPE if needed
        if query_rotary_emb is not None:
            query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
        if key_rotary_emb is not None:
            key = apply_rotary_emb(key, key_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Apply proportional attention if true
        if key_rotary_emb is None:
            softmax_scale = None
        else:
            if base_sequence_length is not None:
                softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
            else:
                softmax_scale = attn.scale

        # perform Grouped-qurey Attention (GQA)
        n_rep = attn.heads // kv_heads
        if n_rep >= 1:
            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
        attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).to(dtype)

        return hidden_states
Problem:
LuminaNextDiT2DModel uses the legacy shared Attention plus LuminaAttnProcessor2_0, whose processor calls F.scaled_dot_product_attention directly and has no _attention_backend / _parallel_config.
Impact:
set_attention_backend() and the attention_backend(...) context manager cannot actually select Flash/Sage/xFormers/context-parallel dispatch for Lumina attention.
Reproduction:
from diffusers import LuminaNextDiT2DModel
model = LuminaNextDiT2DModel(sample_size=4, hidden_size=24, num_layers=1, num_attention_heads=3, num_kv_heads=1, multiple_of=16, learn_sigma=False, cross_attention_dim=32)
processors = [m.processor for m in model.modules() if hasattr(m, "processor")]
print([type(p).__name__ for p in processors])
print([hasattr(p, "_attention_backend") for p in processors])
model.set_attention_backend("native")
print([getattr(p, "_attention_backend", None) for p in processors])
Relevant precedent (diffusers/src/diffusers/models/transformers/transformer_flux.py, lines 75-125 in 0f1abc4):
class FluxAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
Suggested fix:
Port Lumina attention to the current processor pattern: define the processor in the model file, add _attention_backend and _parallel_config, and call dispatch_attention_fn(...) instead of F.scaled_dot_product_attention(...).
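A rough sketch of the shape of that port, modeled on the Flux processor quoted above. Everything here is illustrative rather than a drop-in implementation: the class name is hypothetical, the import path for dispatch_attention_fn mirrors other backend-aware model files and should be verified, and the elided body would keep the existing projection/norm/RoPE/GQA logic:
import torch.nn.functional as F

# Assumed import path; confirm against the current attention dispatch module before use.
from diffusers.models.attention_dispatch import dispatch_attention_fn


class LuminaAttnProcessor:  # hypothetical replacement for LuminaAttnProcessor2_0
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states,
        attention_mask=None,
        query_rotary_emb=None,
        key_rotary_emb=None,
        base_sequence_length=None,
    ):
        # ... keep the existing QKV projection, qk-norm, RoPE, proportional-scale and GQA
        # handling from LuminaAttnProcessor2_0 here, producing query/key/value ...
        # Then route through the dispatcher instead of calling F.scaled_dot_product_attention
        # directly, so set_attention_backend()/attention_backend(...) take effect:
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        return hidden_states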
Issue 7: Lumina quantization docs use the wrong component classes
Affected code (diffusers/docs/source/en/api/pipelines/lumina.md, lines 88-107 in 0f1abc4):
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = Transformer2DModel.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
Problem:
The docs load the Lumina text encoder with T5EncoderModel and the transformer with Transformer2DModel. The checkpoint config is Gemma text encoder plus LuminaNextDiT2DModel.
Impact:
Users following the quantization docs hit config/class errors before inference.
Reproduction:
from diffusers import LuminaNextDiT2DModel, Transformer2DModel
from transformers import AutoConfig, T5EncoderModel
repo = "Alpha-VLLM/Lumina-Next-SFT-diffusers"
text_config = AutoConfig.from_pretrained(repo, subfolder="text_encoder")
print(text_config.model_type) # gemma
try:
T5EncoderModel(text_config)
except Exception as e:
print(type(e).__name__, str(e).splitlines()[0])
config = LuminaNextDiT2DModel.load_config(repo, subfolder="transformer")
try:
Transformer2DModel.from_config(config)
except Exception as e:
print(type(e).__name__, str(e).splitlines()[0])
Relevant precedent:
The pipeline type hints and tests use Gemma classes and LuminaNextDiT2DModel (diffusers/tests/pipelines/lumina/test_lumina_nextdit.py, lines 6-13 in 0f1abc4):
from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    LuminaNextDiT2DModel,
    LuminaPipeline,
)
Suggested fix:
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, LuminaNextDiT2DModel, LuminaPipeline
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig, GemmaModel
text_encoder_8bit = GemmaModel.from_pretrained(..., quantization_config=...)
transformer_8bit = LuminaNextDiT2DModel.from_pretrained(..., quantization_config=...)
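For reference, a fuller version of the corrected snippet. This is a sketch: the repo id, subfolders, and dtype come from the doc excerpt above, and whether GemmaModel or GemmaForCausalLM matches the checkpoint's text_encoder class should be verified against the repo config before updating the docs.
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import LuminaNextDiT2DModel, LuminaPipeline
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig, GemmaModel

repo = "Alpha-VLLM/Lumina-Next-SFT-diffusers"

# 8-bit text encoder via transformers' BitsAndBytesConfig, using a Gemma class
text_encoder_8bit = GemmaModel.from_pretrained(
    repo,
    subfolder="text_encoder",
    quantization_config=TransformersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# 8-bit transformer via diffusers' BitsAndBytesConfig, using the actual Lumina transformer class
transformer_8bit = LuminaNextDiT2DModel.from_pretrained(
    repo,
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

pipeline = LuminaPipeline.from_pretrained(
    repo,
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
)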