lumina model/pipeline review #13634

@hlky

Description

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Reviewed target files, top-level/lazy exports, dummy exports, docs, fast/slow tests, dtype/device handling, offload-related config, and attention processor behavior. Fast model and pipeline tests exist, and a slow Lumina pipeline test exists. Targeted pytest collection was attempted with .venv, but the local Torch build is missing torch._C._distributed_c10d, so shared test mixins fail during collection before Lumina tests run.

Duplicate search status: focused gh CLI and GitHub connector searches found no exact duplicates for these Lumina findings. Related but not exact duplicates: #10827, #13613, #11368.

Issue 1: Deprecated alias is exported but cannot be constructed

Affected code:

class LuminaText2ImgPipeline(LuminaPipeline):
    def __init__(
        self,
        transformer: LuminaNextDiT2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: GemmaPreTrainedModel,
        tokenizer: GemmaTokenizer | GemmaTokenizerFast,
    ):
        deprecation_message = "`LuminaText2ImgPipeline` has been renamed to `LuminaPipeline` and will be removed in a future version. Please use `LuminaPipeline` instead."
        deprecate("diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline", "0.34", deprecation_message)

"LuminaPipeline",
"LuminaText2ImgPipeline",

Problem:
LuminaText2ImgPipeline is still publicly exported, but its constructor calls deprecate(..., "0.34", ...). The current version is 0.38.0.dev0, which is past the deprecation target, so construction raises ValueError instead of emitting a deprecation warning.

Impact:
Users can import the backwards-compatible alias, but any path that instantiates it fails immediately.

Reproduction:

from diffusers import LuminaText2ImgPipeline

LuminaText2ImgPipeline(None, None, None, None, None)

Relevant precedent:
Related rename PR, but it does not remove/fix the current exported alias failure: #10827

Suggested fix:

# Either remove the alias from pipeline_lumina.py, lazy exports, top-level exports,
# and dummy objects, or bump the deprecation target to a future version if keeping it.
deprecate(
    "diffusers.pipelines.lumina.pipeline_lumina.LuminaText2ImgPipeline",
    "1.0.0",
    deprecation_message,
)

Issue 2: VAE scale factor is hardcoded

Affected code:

self.vae_scale_factor = 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

Problem:
vae_scale_factor is hardcoded to 8 instead of derived from vae.config.block_out_channels.

Impact:
With a custom or tiny VAE, the pipeline derives wrong latent sizes, default image sizes, and input divisibility checks.

Reproduction:

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaPipeline

transformer = LuminaNextDiT2DModel(sample_size=4, hidden_size=24, num_layers=1, num_attention_heads=3, num_kv_heads=1, multiple_of=16, learn_sigma=False, cross_attention_dim=32)
vae = AutoencoderKL(block_out_channels=(32, 64), down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"), up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"), latent_channels=4)
pipe = LuminaPipeline(transformer, FlowMatchEulerDiscreteScheduler(), vae, None, None)

print(pipe.vae_scale_factor)  # 8
print(2 ** (len(vae.config.block_out_channels) - 1))  # 2

Relevant precedent:

self.register_modules(
    vae=vae,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    tokenizer=tokenizer,
    tokenizer_2=tokenizer_2,
    transformer=transformer,
    scheduler=scheduler,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)

Suggested fix:

self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
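
With that change, the reproduction above prints 2 for both values when the two-block VAE is used, while the existing fallback of 8 is kept for the no-VAE case.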

Issue 3: max_sequence_length is ignored

Affected code:

def _get_gemma_prompt_embeds(
    self,
    prompt: str | list[str],
    num_images_per_prompt: int = 1,
    device: torch.device | None = None,
    clean_caption: bool | None = False,
    max_length: int | None = None,
):
    device = device or self._execution_device
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)
    prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
    text_inputs = self.tokenizer(
        prompt,
        pad_to_multiple_of=8,
        max_length=self.max_sequence_length,

) = self.encode_prompt(
    prompt,
    do_classifier_free_guidance,
    negative_prompt=negative_prompt,
    num_images_per_prompt=num_images_per_prompt,
    device=device,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    clean_caption=clean_caption,
    max_sequence_length=max_sequence_length,
)

Problem:
__call__ accepts max_sequence_length, but encode_prompt only captures it in **kwargs, and _get_gemma_prompt_embeds always uses self.max_sequence_length.

Impact:
Users cannot shorten or adjust prompt tokenization through the documented pipeline argument.

Reproduction:

import torch
from types import SimpleNamespace
from diffusers import LuminaPipeline

class Tok:
    def __call__(self, prompt, **kw):
        n = kw.get("max_length") or 11
        return SimpleNamespace(input_ids=torch.zeros(len(prompt), n, dtype=torch.long), attention_mask=torch.ones(len(prompt), n))
    def batch_decode(self, ids): return [""]

class Enc(torch.nn.Module):
    dtype = torch.float32
    def forward(self, input_ids, **kw):
        h = torch.zeros(input_ids.shape[0], input_ids.shape[1], 8)
        return SimpleNamespace(hidden_states=[h, h, h])

pipe = LuminaPipeline.__new__(LuminaPipeline)
pipe.max_sequence_length, pipe.tokenizer, pipe.text_encoder, pipe.transformer = 256, Tok(), Enc(), None
embeds, mask, *_ = pipe.encode_prompt("x", do_classifier_free_guidance=False, device=torch.device("cpu"), max_sequence_length=7)
print(embeds.shape, mask.shape)  # torch.Size([1, 256, 8]) torch.Size([1, 256])

Relevant precedent:

def _get_gemma_prompt_embeds(
    self,
    prompt: str | list[str],
    device: torch.device | None = None,
    max_sequence_length: int = 256,
) -> tuple[torch.Tensor, torch.Tensor]:
    device = device or self._execution_device
    prompt = [prompt] if isinstance(prompt, str) else prompt

    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids.to(device)
    untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because Gemma can only handle sequences up to"
            f" {max_sequence_length} tokens: {removed_text}"

Suggested fix:

def _get_gemma_prompt_embeds(..., max_sequence_length: int = 256):
    ...
    max_length=max_sequence_length
    ...
    removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])

# and forward max_sequence_length from encode_prompt into _get_gemma_prompt_embeds
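
On the encode_prompt side, a minimal sketch of that forwarding (the parameter names mirror the existing call site; the defaults and the 256 value are assumptions carried over from the precedent above):

# Hedged sketch: make max_sequence_length an explicit parameter instead of
# leaving it to **kwargs, then pass it down to the Gemma embedding helper.
def encode_prompt(
    self,
    prompt,
    do_classifier_free_guidance=True,
    negative_prompt=None,
    num_images_per_prompt=1,
    device=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    prompt_attention_mask=None,
    negative_prompt_attention_mask=None,
    clean_caption=False,
    max_sequence_length=256,  # assumed default, matching the precedent above
    **kwargs,
):
    ...
    if prompt_embeds is None:
        prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
            prompt=prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            clean_caption=clean_caption,
            max_sequence_length=max_sequence_length,  # forwarded instead of self.max_sequence_length
        )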

Issue 4: Prompt conditioning expansion is wrong for multiple images

Affected code:

_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)

if prompt_embeds is None:
    prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        clean_caption=clean_caption,
    )

# Get negative embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
    negative_prompt = negative_prompt if negative_prompt is not None else ""

    # Normalize str to list
    negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    elif isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]
    elif batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    # Padding negative prompt to the same length with prompt
    prompt_max_length = prompt_embeds.shape[1]
    negative_text_inputs = self.tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=prompt_max_length,
        truncation=True,
        return_tensors="pt",
    )
    negative_text_input_ids = negative_text_inputs.input_ids.to(device)
    negative_prompt_attention_mask = negative_text_inputs.attention_mask.to(device)

    # Get the negative prompt embeddings
    negative_prompt_embeds = self.text_encoder(
        negative_text_input_ids,
        attention_mask=negative_prompt_attention_mask,
        output_hidden_states=True,
    )

    negative_dtype = self.text_encoder.dtype
    negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
    _, seq_len, _ = negative_prompt_embeds.shape
    negative_prompt_embeds = negative_prompt_embeds.to(dtype=negative_dtype, device=device)

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
    negative_prompt_attention_mask = negative_prompt_attention_mask.view(
        batch_size * num_images_per_prompt, -1
    )

return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

Problem:
The generated prompt attention masks use repeat(num_images_per_prompt, 1), which orders the masks as [p1, p2, p1, p2] while the embeddings are expanded to [p1, p1, p2, p2]. In addition, when users pass precomputed prompt/negative embeddings, encode_prompt does not expand them for num_images_per_prompt at all.

Impact:
With batched prompts whose masks differ, embeddings can be paired with the wrong masks, and precomputed-embedding workflows fail or mis-broadcast when generating multiple images per prompt.

Reproduction:

import torch
from diffusers import LuminaPipeline

pipe = LuminaPipeline.__new__(LuminaPipeline)
pe = torch.zeros(1, 5, 8)
pm = torch.ones(1, 5, dtype=torch.long)
ne = torch.ones(1, 5, 8)
nm = torch.ones(1, 5, dtype=torch.long)

out = pipe.encode_prompt(
    prompt=None, do_classifier_free_guidance=True, num_images_per_prompt=2,
    device=torch.device("cpu"), prompt_embeds=pe, prompt_attention_mask=pm,
    negative_prompt_embeds=ne, negative_prompt_attention_mask=nm,
)
print(out[0].shape, out[2].shape)  # both still batch 1; expected batch 2 each

Relevant precedent:

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
    uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
    uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
    max_length = prompt_embeds.shape[1]
    uncond_input = self.tokenizer(
        uncond_tokens,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    negative_prompt_attention_mask = uncond_input.attention_mask
    negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)

    negative_prompt_embeds = self.text_encoder(
        uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
    )
    negative_prompt_embeds = negative_prompt_embeds[0]

if do_classifier_free_guidance:
    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    seq_len = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
    negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

Suggested fix:

bs, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(bs * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt).view(bs * num_images_per_prompt, -1)

if do_classifier_free_guidance:
    bs, seq_len, _ = negative_prompt_embeds.shape
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(bs * num_images_per_prompt, seq_len, -1)
    negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt).view(bs * num_images_per_prompt, -1)
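
If this expansion runs in encode_prompt after both the prompt and negative-prompt branches (rather than only inside the embedding helper), it also covers user-supplied prompt_embeds / negative_prompt_embeds, which the reproduction above shows are currently returned without expansion.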

Issue 5: Provided latents are not cast to the requested dtype

Affected code:

def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    shape = (
        batch_size,
        num_channels_latents,
        int(height) // self.vae_scale_factor,
        int(width) // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    return latents

Problem:
prepare_latents moves user-provided latents to the device but leaves dtype unchanged.

Impact:
Supplying float32 latents to a bf16/fp16 pipeline can feed mismatched activations into lower-precision transformer weights.

Reproduction:

import torch
from diffusers import LuminaPipeline

pipe = LuminaPipeline.__new__(LuminaPipeline)
pipe.vae_scale_factor = 8
latents = torch.ones(1, 4, 2, 2, dtype=torch.float32)
out = pipe.prepare_latents(1, 4, 16, 16, torch.bfloat16, torch.device("cpu"), None, latents)
print(out.dtype)  # torch.float32

Relevant precedent:

def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    if latents is not None:
        return latents.to(device=device, dtype=dtype)

Suggested fix:

else:
    latents = latents.to(device=device, dtype=dtype)
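
Alternatively, prepare_latents could simply early-return latents.to(device=device, dtype=dtype) whenever latents is not None, matching the precedent above.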

Issue 6: Lumina attention bypasses the attention backend dispatcher

Affected code:

self.attn1 = Attention(
    query_dim=dim,
    cross_attention_dim=None,
    dim_head=dim // num_attention_heads,
    qk_norm="layer_norm_across_heads" if qk_norm else None,
    heads=num_attention_heads,
    kv_heads=num_kv_heads,
    eps=1e-5,
    bias=False,
    out_bias=False,
    processor=LuminaAttnProcessor2_0(),
)
self.attn1.to_out = nn.Identity()

# Cross-attention
self.attn2 = Attention(
    query_dim=dim,
    cross_attention_dim=cross_attention_dim,
    dim_head=dim // num_attention_heads,
    qk_norm="layer_norm_across_heads" if qk_norm else None,
    heads=num_attention_heads,
    kv_heads=num_kv_heads,
    eps=1e-5,
    bias=False,
    out_bias=False,
    processor=LuminaAttnProcessor2_0(),

class LuminaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        query_rotary_emb: torch.Tensor | None = None,
        key_rotary_emb: torch.Tensor | None = None,
        base_sequence_length: int | None = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        # Apply Query-Key Norm if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply RoPE if needed
        if query_rotary_emb is not None:
            query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
        if key_rotary_emb is not None:
            key = apply_rotary_emb(key, key_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Apply proportional attention if true
        if key_rotary_emb is None:
            softmax_scale = None
        else:
            if base_sequence_length is not None:
                softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
            else:
                softmax_scale = attn.scale

        # perform Grouped-qurey Attention (GQA)
        n_rep = attn.heads // kv_heads
        if n_rep >= 1:
            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
        attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).to(dtype)

        return hidden_states

Problem:
LuminaNextDiT2DModel uses the legacy shared Attention class with LuminaAttnProcessor2_0, which calls F.scaled_dot_product_attention directly and defines no _attention_backend / _parallel_config attributes.

Impact:
set_attention_backend() and the attention_backend(...) context manager cannot actually select Flash/Sage/xFormers/context-parallel dispatch for Lumina attention.

Reproduction:

from diffusers import LuminaNextDiT2DModel

model = LuminaNextDiT2DModel(sample_size=4, hidden_size=24, num_layers=1, num_attention_heads=3, num_kv_heads=1, multiple_of=16, learn_sigma=False, cross_attention_dim=32)
processors = [m.processor for m in model.modules() if hasattr(m, "processor")]
print([type(p).__name__ for p in processors])
print([hasattr(p, "_attention_backend") for p in processors])
model.set_attention_backend("native")
print([getattr(p, "_attention_backend", None) for p in processors])

Relevant precedent:

class FluxAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )

Suggested fix:
Port Lumina attention to the current processor pattern: define the processor in the model file, add _attention_backend and _parallel_config, and call dispatch_attention_fn(...) instead of F.scaled_dot_product_attention(...).
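
A minimal sketch of that shape, modeled on the FluxAttnProcessor precedent above (the relative import path is an assumption about where the dispatcher lives in-repo, and the elided body is the existing LuminaAttnProcessor2_0 logic carried over unchanged):

# Hedged skeleton only, not a drop-in implementation.
from ..attention_dispatch import dispatch_attention_fn  # assumed in-repo location of the dispatcher


class LuminaAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states,
        attention_mask=None,
        query_rotary_emb=None,
        key_rotary_emb=None,
        base_sequence_length=None,
    ):
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # ... reshaping, qk-norm, RoPE, proportional scale, and GQA repeat as in the existing
        # LuminaAttnProcessor2_0, but without the final transpose(1, 2): in the Flux precedent
        # the dispatcher receives q/k/v in (batch, seq, heads, head_dim) layout ...

        # Route through the dispatcher instead of calling F.scaled_dot_product_attention directly,
        # so set_attention_backend() / attention_backend(...) take effect.
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        return hidden_states.to(query.dtype)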

Issue 7: Lumina quantization docs use the wrong component classes

Affected code:

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, Transformer2DModel, LuminaPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel

quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = T5EncoderModel.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = Transformer2DModel.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

Problem:
The docs load the Lumina text encoder as T5EncoderModel and the transformer as Transformer2DModel, but the checkpoint is configured with a Gemma text encoder and a LuminaNextDiT2DModel transformer.

Impact:
Users following the quantization docs hit config/class errors before inference.

Reproduction:

from diffusers import LuminaNextDiT2DModel, Transformer2DModel
from transformers import AutoConfig, T5EncoderModel

repo = "Alpha-VLLM/Lumina-Next-SFT-diffusers"
text_config = AutoConfig.from_pretrained(repo, subfolder="text_encoder")
print(text_config.model_type)  # gemma
try:
    T5EncoderModel(text_config)
except Exception as e:
    print(type(e).__name__, str(e).splitlines()[0])

config = LuminaNextDiT2DModel.load_config(repo, subfolder="transformer")
try:
    Transformer2DModel.from_config(config)
except Exception as e:
    print(type(e).__name__, str(e).splitlines()[0])

Relevant precedent:
The pipeline type hints and tests use Gemma classes and LuminaNextDiT2DModel.

from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    LuminaNextDiT2DModel,
    LuminaPipeline,
)

Suggested fix:

from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, LuminaNextDiT2DModel, LuminaPipeline
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig, GemmaModel

text_encoder_8bit = GemmaModel.from_pretrained(..., quantization_config=...)
transformer_8bit = LuminaNextDiT2DModel.from_pretrained(..., quantization_config=...)
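
For reference, a fuller hedged sketch of how the corrected docs snippet could read; the exact Gemma class should be matched to the checkpoint's text_encoder config (the fast tests quoted above build the pipeline with GemmaForCausalLM), and the final pipeline assembly is illustrative rather than copied from the docs:

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, LuminaNextDiT2DModel, LuminaPipeline
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig, GemmaModel

# 8-bit text encoder via transformers' BitsAndBytesConfig
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = GemmaModel.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers",
    subfolder="text_encoder",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# 8-bit transformer via diffusers' BitsAndBytesConfig, using the actual Lumina transformer class
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = LuminaNextDiT2DModel.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# Assemble the pipeline with the quantized components
pipeline = LuminaPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
)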
