bria_fibo model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Duplicate search: checked GitHub issues/PRs for bria_fibo, BriaFibo, FIBO, affected class names, and failure modes. Found integration/refactor PRs #12545, #12688, #12731, #12930, #13341; no duplicate for the issues below. Public top-level imports succeeded.
Test note: standalone reproductions were run with the project .venv. Full pytest collection of the target tests failed before any test executed because this Windows torch build lacks torch._C._distributed_c10d.
Issue 1: prompt_embeds is public but unusable
Affected code:
src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py, lines 254-332 (pipeline_bria_fibo_edit.py, lines 412-490, contains an identical copy):

if prompt_embeds is None:
    prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
    )
    prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
    prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]

if guidance_scale > 1:
    if isinstance(negative_prompt, list) and negative_prompt[0] is None:
        negative_prompt = ""
    negative_prompt = negative_prompt or ""
    negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    elif batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
        prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
    )
    negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
    negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]

if self.text_encoder is not None:
    if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

# Pad to longest
if prompt_attention_mask is not None:
    prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)

if negative_prompt_embeds is not None:
    if negative_prompt_attention_mask is not None:
        negative_prompt_attention_mask = negative_prompt_attention_mask.to(
            device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
        )
    max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])

    prompt_embeds, prompt_attention_mask = self.pad_embedding(
        prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
    )
    prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]

    negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
        negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
    )
    negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
else:
    max_tokens = prompt_embeds.shape[1]
    prompt_embeds, prompt_attention_mask = self.pad_embedding(
        prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
    )
    negative_prompt_layers = None

dtype = self.text_encoder.dtype
text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)

return (
    prompt_embeds,
    negative_prompt_embeds,
    text_ids,
    prompt_attention_mask,
    negative_prompt_attention_mask,
    prompt_layers,
    negative_prompt_layers,
Problem:
Both pipelines expose prompt_embeds, but encode_prompt() only defines prompt_layers when it encodes prompt itself. Passing precomputed embeddings raises UnboundLocalError. negative_prompt_embeds is also not honored because negative embeddings are always recomputed when guidance_scale > 1.
Impact:
Users cannot use documented precomputed embedding workflows, prompt weighting, cached text encoder outputs, or callback-modified embeddings reliably.
Reproduction:
import torch
from types import SimpleNamespace
from diffusers import BriaFiboPipeline
pipe = BriaFiboPipeline.__new__(BriaFiboPipeline)
pipe.transformer = SimpleNamespace(dtype=torch.float32)
pipe.text_encoder = SimpleNamespace(dtype=torch.float32)
pipe.encode_prompt(
prompt=None,
prompt_embeds=torch.zeros(1, 1, 64),
guidance_scale=1.0,
device=torch.device("cpu"),
)
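# raises UnboundLocalError: `prompt_layers` is referenced in the return tuple but never assigned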
Relevant precedent:
Flux validates all required precomputed conditioning inputs together (src/diffusers/pipelines/flux/pipeline_flux.py, lines 494-500):

if prompt_embeds is not None and pooled_prompt_embeds is None:
    raise ValueError(
        "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
    )
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
    raise ValueError(
        "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
Suggested fix:
if prompt_embeds is not None and prompt_layers is None:
raise ValueError("`prompt_embeds` requires precomputed `prompt_layers`, or pass `prompt` instead.")
Better: add public prompt_layers / negative_prompt_layers inputs and honor negative_prompt_embeds instead of recomputing it.
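A minimal sketch of that direction, reusing the existing get_prompt_embeds() call sites; the prompt_layers / negative_prompt_layers parameters and their plumbing are hypothetical:

# Sketch only: `prompt_layers` / `negative_prompt_layers` are proposed new public
# parameters of encode_prompt(), defaulting to None.
if prompt_embeds is not None and prompt_layers is None:
    raise ValueError("`prompt_embeds` requires `prompt_layers` from the same text encoder; pass both or pass `prompt`.")
if negative_prompt_embeds is not None and negative_prompt_layers is None:
    raise ValueError("`negative_prompt_embeds` requires `negative_prompt_layers`.")

if prompt_embeds is None:
    prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
    )

# Only recompute negatives when the caller did not supply them.
if guidance_scale > 1 and negative_prompt_embeds is None:
    negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
        prompt=negative_prompt,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
    )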
Issue 2: Custom timesteps are ignored
Affected code:
src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py, lines 466-480:

timesteps: list[int] = None,
guidance_scale: float = 5,
negative_prompt: str | list[str] | None = None,
num_images_per_prompt: int | None = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.FloatTensor | None = None,
prompt_embeds: torch.FloatTensor | None = None,
negative_prompt_embeds: torch.FloatTensor | None = None,
output_type: str | None = "pil",
return_dict: bool = True,
joint_attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end: Callable[[int, int], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 3000,
do_patching=False,

src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py, lines 674-680:

timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler,
    num_inference_steps=num_inference_steps,
    device=device,
    timesteps=None,
    sigmas=sigmas,
    mu=mu,

src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py, lines 626-642:

timesteps: List[int] = None,
seed: int | None = None,
guidance_scale: float = 5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 3000,
do_patching=False,
_auto_resize: bool = True,

src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py, lines 892-898:

timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler,
    num_inference_steps=num_inference_steps,
    device=device,
    timesteps=None,
    sigmas=sigmas,
    mu=mu,
Problem:
timesteps is accepted and documented, but both pipelines call retrieve_timesteps(..., timesteps=None, sigmas=sigmas, ...).
Impact:
Users requesting a custom timestep schedule silently get the default schedule.
Reproduction:
import inspect
from diffusers import BriaFiboPipeline, BriaFiboEditPipeline
for cls in (BriaFiboPipeline, BriaFiboEditPipeline):
source = inspect.getsource(cls.__call__)
assert "timesteps=None" not in source, f"{cls.__name__} drops custom timesteps"
Relevant precedent:
The shared helper supports timesteps (src/diffusers/pipelines/flux/pipeline_flux.py, lines 88-129):

def retrieve_timesteps(
    scheduler,
    num_inference_steps: int | None = None,
    device: str | torch.device | None = None,
    timesteps: list[int] | None = None,
    sigmas: list[float] | None = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
            passed, `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
Suggested fix:
sigmas = None if timesteps is not None else np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps=num_inference_steps,
device=device,
timesteps=timesteps,
sigmas=sigmas,
mu=mu,
)
Issue 3: Tensor images crash in BriaFiboEditPipeline
Affected code:
src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py, lines 809-812:

# Preprocess image
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
    image = self.image_processor.resize(image, height, width)
    image = self.image_processor.preprocess(image, height, width)
Problem:
The tensor-image path reads self.latent_channels, but that attribute is never defined.
Impact:
image=torch.Tensor(...) is accepted by validation and documented typing, but crashes before preprocessing.
Reproduction:
import torch
from diffusers import BriaFiboEditPipeline
pipe = BriaFiboEditPipeline.__new__(BriaFiboEditPipeline)
image = torch.zeros(1, 3, 32, 32)
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == pipe.latent_channels):
pass
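# evaluating pipe.latent_channels raises AttributeError: the attribute is never defined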
Relevant precedent:
N/A.
Suggested fix:
if image is not None:
image = self.image_processor.resize(image, height, width)
image = self.image_processor.preprocess(image, height, width)
If latent image input is intended, define the latent channel count from self.transformer.config.in_channels and validate it explicitly.
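If the latent path is kept, a sketch of the explicit variant (assuming the transformer config's in_channels equals the latent channel count):

# Derive the latent channel count instead of reading an undefined attribute.
latent_channels = self.transformer.config.in_channels
if image is not None:
    if isinstance(image, torch.Tensor) and image.size(1) == latent_channels:
        pass  # input is already latents; skip resizing and preprocessing
    else:
        image = self.image_processor.resize(image, height, width)
        image = self.image_processor.preprocess(image, height, width)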
Issue 4: Multiple generated images get malformed output shape
Affected code:
src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py, lines 772-780 (pipeline_bria_fibo_edit.py, lines 995-1003, contains an identical copy):

image = []
for scaled_latent in latents_scaled:
    curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
    curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
    image.append(curr_image)
if len(image) == 1:
    image = image[0]
else:
    image = np.stack(image, axis=0)
Problem:
Each per-sample postprocess(..., output_type="np") returns shape (1, H, W, C), then the pipeline uses np.stack, producing (N, 1, H, W, C) instead of (N, H, W, C). PIL output becomes nested lists.
Impact:
num_images_per_prompt > 1 returns an incompatible output structure.
Reproduction:
import numpy as np
per_sample = [np.zeros((1, 32, 32, 3)), np.zeros((1, 32, 32, 3))]
print(np.stack(per_sample, axis=0).shape) # (2, 1, 32, 32, 3)
Relevant precedent:
Bria decodes and postprocesses the batch directly (src/diffusers/pipelines/bria/pipeline_bria.py, lines 730-733):

latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
Suggested fix:
curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
if output_type == "np":
image.append(curr_image[0])
else:
image.extend(curr_image)
...
if output_type == "np":
image = np.stack(image, axis=0)
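Put together, a sketch of the corrected decode loop (names as in the affected code; non-np output types flow through the extend branch):

image = []
for scaled_latent in latents_scaled:
    curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
    curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
    if output_type == "np":
        image.append(curr_image[0])  # drop the per-sample singleton batch dim
    else:
        image.extend(curr_image)  # flatten per-sample lists (e.g. PIL images)
if output_type == "np":
    image = np.stack(image, axis=0)  # (N, H, W, C)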
Issue 5: Edit pipeline does not duplicate image latents for num_images_per_prompt
Affected code:
src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py, lines 831-842:

if image is not None:
    image_latents, image_ids = self.prepare_image_latents(
        image=image,
        batch_size=batch_size * num_images_per_prompt,
        num_channels_latents=num_channels_latents,
        height=height,
        width=width,
        dtype=prompt_embeds.dtype,
        device=device,
        generator=generator,
    )
    latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0)  # dim 0 is sequence dimension

src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py, lines 1037-1049:

image_latents_cthw = self.vae.encode(image.unsqueeze(2)).latent_dist.mean
latents_scaled = [(latent - latents_mean) * latents_std for latent in image_latents_cthw]
image_latents_cthw = torch.concat(latents_scaled, dim=0)
image_latents_bchw = image_latents_cthw[:, :, 0, :, :]

image_latent_height, image_latent_width = image_latents_bchw.shape[2:]
image_latents_bsd = self._pack_latents_no_patch(
    latents=image_latents_bchw,
    batch_size=batch_size,
    num_channels_latents=num_channels_latents,
    height=image_latent_height,
    width=image_latent_width,
)
Problem:
prepare_image_latents() receives batch_size * num_images_per_prompt, but the encoded image batch remains at the original image batch size before reshape.
Impact:
BriaFiboEditPipeline(..., image=..., num_images_per_prompt=2) fails with an invalid reshape. The fast tests skip batching, so this is not covered.
Reproduction:
import torch
from diffusers import BriaFiboEditPipeline
latents = torch.zeros(1, 16, 2, 2)
BriaFiboEditPipeline._pack_latents_no_patch(
latents=latents,
batch_size=2,
num_channels_latents=16,
height=2,
width=2,
)
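# fails with an invalid reshape: one encoded image cannot be packed as batch_size=2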
Relevant precedent:
N/A.
Suggested fix:
repeat_by = batch_size // image_latents_bchw.shape[0]
image_latents_bchw = image_latents_bchw.repeat_interleave(repeat_by, dim=0)
Also unskip/add batch tests for edit.
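A standalone shape check of the suggested repeat (tensor sizes are hypothetical):

import torch

image_latents_bchw = torch.randn(1, 16, 2, 2)  # one encoded input image
batch_size = 2  # batch_size * num_images_per_prompt at the call site
repeat_by = batch_size // image_latents_bchw.shape[0]
image_latents_bchw = image_latents_bchw.repeat_interleave(repeat_by, dim=0)
print(image_latents_bchw.shape)  # torch.Size([2, 16, 2, 2])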
Issue 6: guidance_embeds=True cannot construct the transformer
Affected code:
src/diffusers/models/transformers/transformer_bria_fibo.py, lines 415-426:

class BriaFiboTimestepProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim, time_theta):
        super().__init__()

        self.time_proj = BriaFiboTimesteps(
            num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
        )
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep, dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype))  # (N, D)

src/diffusers/models/transformers/transformer_bria_fibo.py, lines 472-473:

if guidance_embeds:
    self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim)

src/diffusers/models/transformers/transformer_bria_fibo.py, lines 558-559:

if guidance:
    temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
Problem:
BriaFiboTimestepProjEmbeddings requires time_theta, but guidance_embed is constructed without it. The forward path also uses if guidance: on a tensor.
Impact:
Any config/checkpoint with guidance_embeds=True fails during model construction, and the forward branch would be ambiguous for multi-element tensors after construction is fixed.
Reproduction:
from diffusers import BriaFiboTransformer2DModel
BriaFiboTransformer2DModel(
patch_size=1,
in_channels=16,
num_layers=1,
num_single_layers=1,
attention_head_dim=8,
num_attention_heads=2,
joint_attention_dim=64,
text_encoder_dim=32,
axes_dims_rope=[0, 4, 4],
guidance_embeds=True,
)
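# raises TypeError during construction: BriaFiboTimestepProjEmbeddings is called without `time_theta`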
Relevant precedent:
Flux checks guidance is None, not tensor truthiness (src/diffusers/models/transformers/transformer_flux.py, lines 682-690):

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
    guidance = guidance.to(hidden_states.dtype) * 1000

temb = (
    self.time_text_embed(timestep, pooled_projections)
    if guidance is None
    else self.time_text_embed(timestep, guidance, pooled_projections)
)
Suggested fix:
if guidance_embeds:
self.guidance_embed = BriaFiboTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
time_theta=time_theta,
)
...
if guidance is not None:
temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
Issue 7: Dense additive attention masks disable flash/sage attention
Affected code:
src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py, lines 647-653:

attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
attention_mask = self._prepare_attention_mask(attention_mask)  # batch, seq => batch, seq, seq
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype)  # for head broadcasting

if self._joint_attention_kwargs is None:
    self._joint_attention_kwargs = {}
self._joint_attention_kwargs["attention_mask"] = attention_mask

src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py, lines 852-871:

if image_latents is None:
    attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
else:
    image_latent_attention_mask = torch.ones(
        [image_latents.shape[0], image_latents.shape[1]],
        dtype=image_latents.dtype,
        device=image_latents.device,
    )
    if guidance_scale > 1:
        image_latent_attention_mask = image_latent_attention_mask.repeat(2, 1)
    attention_mask = torch.cat(
        [prompt_attention_mask, latent_attention_mask, image_latent_attention_mask], dim=1
    )

attention_mask = self.create_attention_matrix(attention_mask)  # batch, seq => batch, seq, seq
attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype)  # for head broadcasting

if self._joint_attention_kwargs is None:
    self._joint_attention_kwargs = {}
self._joint_attention_kwargs["attention_mask"] = attention_mask

src/diffusers/models/transformers/transformer_bria_fibo.py, lines 111-118:

hidden_states = dispatch_attention_fn(
    query,
    key,
    value,
    attn_mask=attention_mask,
    backend=self._attention_backend,
    parallel_config=self._parallel_config,
)
Problem:
The pipelines convert padding masks into dense (B, 1, L, L) additive float masks. The masks carry only padding information, so they could be passed as boolean key masks instead, but the dense form hard-fails on the flash-attn and sage backends.
Impact:
Users selecting optimized attention backends hit avoidable runtime failures.
Reproduction:
import torch
from diffusers.pipelines.bria_fibo.pipeline_bria_fibo import BriaFiboPipeline
from diffusers.models.attention_dispatch import _flash_attention
mask = torch.tensor([[1, 1, 0, 1]], dtype=torch.float32)
dense_mask = BriaFiboPipeline._prepare_attention_mask(mask).unsqueeze(1)
q = k = v = torch.randn(1, 4, 2, 8)
_flash_attention(q, k, v, attn_mask=dense_mask)
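# hard-fails: the flash-attn backend does not accept dense additive attention masks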
Relevant precedent:
QwenImage builds a bool joint mask instead (src/diffusers/models/transformers/transformer_qwenimage.py, lines 946-952):

if encoder_hidden_states_mask is not None:
    # Build joint mask: [text_mask, all_ones_for_image]
    batch_size, image_seq_len = hidden_states.shape[:2]
    image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
    joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
    joint_attention_mask = joint_attention_mask[:, None, None, :]
    block_attention_kwargs["attention_mask"] = joint_attention_mask
Suggested fix:
attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1).to(torch.bool)
attention_mask = attention_mask[:, None, None, :]
self._joint_attention_kwargs["attention_mask"] = attention_mask
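As a standalone sanity check, PyTorch SDPA accepts a broadcastable boolean key mask of exactly this shape (a sketch; each dispatch backend's handling should still be verified):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 4, 4, 8)  # (batch, heads, seq, head_dim)
key_mask = torch.tensor([[True, True, False, True]])  # True = attend, False = padding
attn_mask = key_mask[:, None, None, :]  # (B, 1, 1, L) broadcasts over heads and queries
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 4, 4, 8])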
Issue 8: VAE scale factor is hardcoded instead of read from config
Affected code:
src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py, lines 110-112:

self.vae_scale_factor = 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.default_sample_size = 64

src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py, lines 268-270:

self.vae_scale_factor = 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)  # * 2)
self.default_sample_size = 32  # 64
Problem:
Both pipelines set self.vae_scale_factor = 16 even though AutoencoderKLWan stores scale_factor_spatial in config.
Impact:
Custom or future Fibo-compatible VAEs with a different spatial scale serialize/load correctly but produce wrong latent sizes in the pipeline.
Reproduction:
from diffusers import AutoencoderKLWan, BriaFiboPipeline
vae = AutoencoderKLWan(base_dim=8, decoder_base_dim=8, num_res_blocks=1, z_dim=4, dim_mult=[1], temperal_downsample=[])
pipe = BriaFiboPipeline(transformer=None, scheduler=None, vae=vae, text_encoder=None, tokenizer=None)
print(vae.config.scale_factor_spatial) # 8
print(pipe.vae_scale_factor) # 16
Relevant precedent:
Wan reads the VAE scale factor from config (src/diffusers/pipelines/wan/pipeline_wan.py, lines 154-156):

self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
Suggested fix:
self.vae_scale_factor = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 16
Issue 9: Transformer is missing _no_split_modules
Affected code:
src/diffusers/models/transformers/transformer_bria_fibo.py, lines 430-445:

class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    Parameters:
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
        ...
    """

    _supports_gradient_checkpointing = True
Problem:
The model enables gradient checkpointing but does not declare _no_split_modules for its transformer blocks.
Impact:
device_map / offload placement can split residual attention blocks across devices, unlike comparable transformer integrations.
Reproduction:
from diffusers import BriaFiboTransformer2DModel
print(getattr(BriaFiboTransformer2DModel, "_no_split_modules", None))
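# prints None: no block classes are declared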
Relevant precedent:
Flux declares both block classes (src/diffusers/models/transformers/transformer_flux.py, lines 565-567):

_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
Suggested fix:
_no_split_modules = ["BriaFiboTransformerBlock", "BriaFiboSingleTransformerBlock"]
Issue 10: Edit example docstring is stale and not runnable
Affected code:
src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py, lines 53-86:

# TODO: Update example docstring
EXAMPLE_DOC_STRING = """
    Example:
        ```python
        import torch
        from diffusers import BriaFiboEditPipeline
        from diffusers.modular_pipelines import ModularPipeline

        torch.set_grad_enabled(False)
        vlm_pipe = ModularPipelineBlocks.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
        vlm_pipe = vlm_pipe.init_pipeline()

        pipe = BriaFiboEditPipeline.from_pretrained(
            "briaai/fibo-edit",
            torch_dtype=torch.bfloat16,
        )
        pipe.to("cuda")

        output = vlm_pipe(
            prompt="A hyper-detailed, ultra-fluffy owl sitting in the trees at night, looking directly at the camera with wide, adorable, expressive eyes. Its feathers are soft and voluminous, catching the cool moonlight with subtle silver highlights. The owl's gaze is curious and full of charm, giving it a whimsical, storybook-like personality."
        )
        json_prompt_generate = json.loads(output.values["json_prompt"])

        image = Image.open("image_generate.png")

        edit_prompt = "Make the owl to be a cat"

        json_prompt_generate["edit_instruction"] = edit_prompt

        results_generate = pipe(
            prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=3.5, image=image, output_type="np"
        )
        ```
"""
Problem:
The file contains # TODO: Update example docstring, imports ModularPipeline, then uses undefined ModularPipelineBlocks.
Impact:
Generated docs include a broken example, and the TODO violates the review rule against ephemeral PR-context comments.
Reproduction:
namespace = {}
exec(
"from diffusers.modular_pipelines import ModularPipeline\n"
"ModularPipelineBlocks.from_pretrained('briaai/FIBO-VLM-prompt-to-JSON')",
namespace,
)
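# raises NameError: name 'ModularPipelineBlocks' is not defined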
Relevant precedent:
The text-to-image Fibo example uses ModularPipeline.from_pretrained (src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py, lines 51-57):

from diffusers import BriaFiboPipeline
from diffusers.modular_pipelines import ModularPipeline

torch.set_grad_enabled(False)
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)

pipe = BriaFiboPipeline.from_pretrained(
Suggested fix:
vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
Also remove the TODO and verify the model id casing.
Issue 11: Slow tests are missing
Affected code:
tests/models/transformers/test_models_transformer_bria_fibo.py, lines 29-37:

class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = BriaFiboTransformer2DModel
    main_input_name = "hidden_states"
    # We override the items here because the transformer under consideration is small.
    model_split_percents = [0.8, 0.7, 0.7]

    # Skip setting testing with default: AttnProcessor
    uses_custom_attn_processor = True

tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py, lines 39-47:

class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = BriaFiboPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale"])
    batch_params = frozenset(["prompt"])
    test_xformers_attention = False
    test_layerwise_casting = False
    test_group_offloading = False
    supports_dduf = False

tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py, lines 40-47:

class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = BriaFiboEditPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale"])
    batch_params = frozenset(["prompt"])
    test_xformers_attention = False
    test_layerwise_casting = False
    test_group_offloading = False
    supports_dduf = False
Problem:
Fast model and pipeline tests exist, but there are no @slow tests for BriaFiboPipeline or BriaFiboEditPipeline.
Impact:
The gated real checkpoints are never exercised for loading, dtype/offload behavior, output shape, JSON prompt handling, or edit image/mask behavior.
Reproduction:
from pathlib import Path
paths = [
Path("tests/models/transformers/test_models_transformer_bria_fibo.py"),
Path("tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py"),
Path("tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py"),
]
print({str(path): "@slow" in path.read_text() for path in paths})
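# expected today: every file reports False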
Relevant precedent:
Bria has a slow pipeline test (tests/pipelines/bria/test_pipeline_bria.py, lines 241-245):

@slow
@require_torch_accelerator
class BriaPipelineSlowTests(unittest.TestCase):
    pipeline_class = BriaPipeline
    repo_id = "briaai/BRIA-3.2"
Suggested fix:
Add gated slow smoke tests for briaai/FIBO and briaai/Fibo-Edit, using torch_dtype=torch.bfloat16, enable_model_cpu_offload(), a short schedule, deterministic seed/generator, and expected output shape/value slices.
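A minimal skeleton for the text-to-image case (a sketch; the JSON prompt shape, the .images output attribute, and the 1024x1024 default resolution are assumptions to confirm against a verified run):

import unittest

import torch

from diffusers import BriaFiboPipeline
from diffusers.utils.testing_utils import require_torch_accelerator, slow


@slow
@require_torch_accelerator
class BriaFiboPipelineSlowTests(unittest.TestCase):
    repo_id = "briaai/FIBO"  # gated checkpoint

    def test_smoke(self):
        pipe = BriaFiboPipeline.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
        pipe.enable_model_cpu_offload()
        generator = torch.Generator("cpu").manual_seed(0)
        # Hypothetical structured prompt; FIBO consumes JSON prompts produced by the VLM helper.
        result = pipe(
            prompt='{"subject": "a red apple on a wooden table"}',
            num_inference_steps=4,  # short schedule keeps the smoke test cheap
            guidance_scale=3.5,
            generator=generator,
            output_type="np",
        )
        image = result.images[0]  # assumes the standard ImagePipelineOutput interface
        self.assertEqual(image.shape, (1024, 1024, 3))  # assumed default resolution
        # Then compare a small pixel slice against values recorded from a verified run.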