z_image model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Duplicate search status: checked GitHub Issues/PRs for z_image, affected class names, and failure modes. Duplicates found for Issue 1 and Issue 3; noted below.
Issue 1: ZImageOmniPipeline crashes when guidance_scale=0
Affected code:
src/diffusers/pipelines/z_image/pipeline_z_image_omni.py, lines 579-592:

condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds]

if self.do_classifier_free_guidance:
    negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds]

# Repeat prompt_embeds for num_images_per_prompt
if num_images_per_prompt > 1:
    prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
    if self.do_classifier_free_guidance and negative_prompt_embeds:
        negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]

condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds]
negative_condition_siglip_embeds = [
    None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds
]
Problem:
negative_condition_siglip_embeds is only assigned inside if self.do_classifier_free_guidance, but it is normalized unconditionally immediately after. The public example uses guidance_scale=0.0, so the documented Omni path raises before denoising.
Duplicate:
Already covered by open PR #13527. This is not a new finding.
Impact:
ZImageOmniPipeline(..., guidance_scale=0.0) fails for the documented turbo/no-CFG usage.
Reproduction:
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImageOmniPipeline, ZImageTransformer2DModel
transformer = ZImageTransformer2DModel(
all_patch_size=(2,), all_f_patch_size=(1,), in_channels=4, dim=16,
n_layers=1, n_refiner_layers=1, n_heads=2, n_kv_heads=2,
cap_feat_dim=8, axes_dims=[4, 2, 2], axes_lens=[32, 32, 32],
)
vae = AutoencoderKL(
in_channels=3, out_channels=3, down_block_types=["DownEncoderBlock2D"], up_block_types=["UpDecoderBlock2D"],
block_out_channels=[16], layers_per_block=1, latent_channels=4, norm_num_groups=4, sample_size=32,
scaling_factor=0.3611, shift_factor=0.1159,
)
pipe = ZImageOmniPipeline(FlowMatchEulerDiscreteScheduler(), vae, None, None, transformer, None, None)
pipe(prompt_embeds=[[torch.randn(3, 8)]], height=32, width=32, num_inference_steps=1, guidance_scale=0.0, output_type="latent")
Relevant precedent:
Open duplicate PR: #13527
Suggested fix:
condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds]
if self.do_classifier_free_guidance:
negative_condition_siglip_embeds = [
None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds
]
else:
negative_condition_siglip_embeds = None
Issue 2: Omni condition-image encoding hard-casts VAE input to bfloat16
Affected code:
src/diffusers/pipelines/z_image/pipeline_z_image_omni.py, lines 293-306:

def prepare_image_latents(
    self,
    images: list[torch.Tensor],
    batch_size,
    device,
    dtype,
):
    image_latents = []
    for image in images:
        image = image.to(device=device, dtype=dtype)
        image_latent = (
            self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor
        ) * self.vae.config.scaling_factor
        image_latent = image_latent.unsqueeze(1).to(dtype)
Problem:
prepare_image_latents() calls self.vae.encode(image.bfloat16()) regardless of the VAE dtype. A float32 VAE on CPU receives bf16 inputs with float32 weights and raises a dtype mismatch.
Impact:
Omni image-conditioned generation fails outside the exact bf16 VAE setup. It also violates the dtype/device handling rule by hardcoding a dtype in pipeline runtime code.
Reproduction:
import torch
from diffusers import AutoencoderKL, ZImageOmniPipeline
vae = AutoencoderKL(
in_channels=3, out_channels=3, down_block_types=["DownEncoderBlock2D"], up_block_types=["UpDecoderBlock2D"],
block_out_channels=[16], layers_per_block=1, latent_channels=4, norm_num_groups=4, sample_size=32,
scaling_factor=0.3611, shift_factor=0.1159,
)
pipe = object.__new__(ZImageOmniPipeline)
pipe.vae = vae
pipe.prepare_image_latents([torch.rand(1, 3, 32, 32)], 1, torch.device("cpu"), torch.float32)
Relevant precedent:
Other image-encoding paths convert to the requested/vae dtype before vae.encode, not a fixed bf16 dtype.
Suggested fix:
vae_dtype = self.vae.dtype
image = image.to(device=device, dtype=vae_dtype)
image_latent = (
self.vae.encode(image).latent_dist.mode()[0] - self.vae.config.shift_factor
) * self.vae.config.scaling_factor
image_latent = image_latent.unsqueeze(1).to(dtype)
Issue 3: ZImageControlNetModel has gradient-checkpointing flag but never initializes it
Affected code:
src/diffusers/models/controlnets/controlnet_z_image.py, lines 433-517:

_supports_gradient_checkpointing = True

@register_to_config
def __init__(
    self,
    control_layers_places: list[int] = None,
    control_refiner_layers_places: list[int] = None,
    control_in_dim=None,
    add_control_noise_refiner: Literal["control_layers", "control_noise_refiner"] | None = None,
    all_patch_size=(2,),
    all_f_patch_size=(1,),
    dim=3840,
    n_refiner_layers=2,
    n_heads=30,
    n_kv_heads=30,
    norm_eps=1e-5,
    qk_norm=True,
):
    super().__init__()
    self.control_layers_places = control_layers_places
    self.control_in_dim = control_in_dim
    self.control_refiner_layers_places = control_refiner_layers_places
    self.add_control_noise_refiner = add_control_noise_refiner

    assert 0 in self.control_layers_places

    # control blocks
    self.control_layers = nn.ModuleList(
        [
            ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
            for i in self.control_layers_places
        ]
    )

    # control patch embeddings
    all_x_embedder = {}
    for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
        x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
        all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

    self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
    if self.add_control_noise_refiner == "control_layers":
        self.control_noise_refiner = None
    elif self.add_control_noise_refiner == "control_noise_refiner":
        self.control_noise_refiner = nn.ModuleList(
            [
                ZImageControlTransformerBlock(
                    1000 + layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    block_id=layer_id,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
    else:
        self.control_noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    1000 + layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

    self.t_scale: float | None = None
    self.t_embedder: TimestepEmbedder | None = None
    self.all_x_embedder: nn.ModuleDict | None = None
    self.cap_embedder: nn.Sequential | None = None
    self.rope_embedder: RopeEmbedder | None = None
    self.noise_refiner: nn.ModuleList | None = None
    self.context_refiner: nn.ModuleList | None = None
    self.x_pad_token: nn.Parameter | None = None
    self.cap_pad_token: nn.Parameter | None = None

src/diffusers/models/controlnets/controlnet_z_image.py, lines 753-833:

if torch.is_grad_enabled() and self.gradient_checkpointing:
    for layer_idx, layer in enumerate(self.noise_refiner):
        x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
        if noise_refiner_block_samples is not None:
            if layer_idx in noise_refiner_block_samples:
                x = x + noise_refiner_block_samples[layer_idx]
else:
    for layer_idx, layer in enumerate(self.noise_refiner):
        x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
        if noise_refiner_block_samples is not None:
            if layer_idx in noise_refiner_block_samples:
                x = x + noise_refiner_block_samples[layer_idx]

# cap embed & refine
cap_item_seqlens = [len(_) for _ in cap_feats]
cap_max_item_seqlen = max(cap_item_seqlens)

cap_feats = torch.cat(cap_feats, dim=0)
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(
    self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
)

cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]

cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
    cap_attn_mask[i, :seq_len] = 1

if torch.is_grad_enabled() and self.gradient_checkpointing:
    for layer in self.context_refiner:
        cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
else:
    for layer in self.context_refiner:
        cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)

# unified
unified = []
unified_freqs_cis = []
for i in range(bsz):
    x_len = x_item_seqlens[i]
    cap_len = cap_item_seqlens[i]
    unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
    unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
assert unified_item_seqlens == [len(_) for _ in unified]
unified_max_item_seqlen = max(unified_item_seqlens)

unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
    unified_attn_mask[i, :seq_len] = 1

## ControlNet start
if not self.add_control_noise_refiner:
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        for layer in self.control_noise_refiner:
            control_context = self._gradient_checkpointing_func(
                layer, control_context, x_attn_mask, x_freqs_cis, adaln_input
            )
    else:
        for layer in self.control_noise_refiner:
            control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input)

# unified
control_context_unified = []
for i in range(bsz):
    x_len = x_item_seqlens[i]
    cap_len = cap_item_seqlens[i]
    control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]]))
control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)

for layer in self.control_layers:
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        control_context_unified = self._gradient_checkpointing_func(
Problem:
The model sets _supports_gradient_checkpointing = True and branches on self.gradient_checkpointing, but __init__ never sets self.gradient_checkpointing = False.
Duplicate:
Already covered by open PR #13267. This is not a new finding.
Impact:
Direct grad-enabled forward, training, or checkpointing setup fails with AttributeError.
Reproduction:
import torch
from diffusers import ZImageControlNetModel, ZImageTransformer2DModel
transformer = ZImageTransformer2DModel(
all_patch_size=(2,), all_f_patch_size=(1,), in_channels=4, dim=16,
n_layers=1, n_refiner_layers=1, n_heads=2, n_kv_heads=2,
cap_feat_dim=8, axes_dims=[4, 2, 2], axes_lens=[64, 64, 64],
)
controlnet = ZImageControlNetModel(
control_layers_places=[0], control_refiner_layers_places=[0], control_in_dim=4,
all_patch_size=(2,), all_f_patch_size=(1,), dim=16, n_refiner_layers=1,
n_heads=2, n_kv_heads=2,
)
controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer)
controlnet([torch.randn(4, 1, 32, 32)], torch.tensor([0.5]), [torch.randn(3, 8)], [torch.randn(4, 1, 32, 32)])
Relevant precedent:
src/diffusers/models/controlnets/controlnet_qwenimage.py, line 101:

self.gradient_checkpointing = False
Suggested fix:
self.gradient_checkpointing = False
Issue 4: ZImageInpaintPipeline.masked_image_latents is accepted but ignored
Affected code:
src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py, lines 537-576:

    masked_image_latents: Optional[torch.FloatTensor] = None,
    strength: float = 1.0,
    height: int | None = None,
    width: int | None = None,
    num_inference_steps: int = 50,
    sigmas: list[float] | None = None,
    guidance_scale: float = 5.0,
    cfg_normalization: bool = False,
    cfg_truncation: float = 1.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: torch.Generator | list[torch.Generator] | None = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    output_type: str = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
):
    r"""
    Function invoked when calling the pipeline for inpainting.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            instead.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
            numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
            list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
            a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
        mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image`, numpy array or tensor representing a mask image for inpainting. White pixels (value 1) in the
            mask will be inpainted, black pixels (value 0) will be preserved from the original image.
        masked_image_latents (`torch.FloatTensor`, *optional*):
            Pre-encoded masked image latents. If provided, the masked image encoding step will be skipped.
        strength (`float`, *optional*, defaults to 1.0):

src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py, lines 788-797:

if masked_image_latents is None:
    masked_image = init_image * (mask < 0.5)
else:
    masked_image = None  # Will use provided masked_image_latents

mask, masked_image_latents = self.prepare_mask_latents(
    mask,
    masked_image if masked_image is not None else init_image,
    actual_batch_size,
    height,
Problem:
The public argument's docstring says precomputed masked latents skip the encoding step, but the value is never passed into prepare_mask_latents() and never affects denoising: different supplied masked_image_latents tensors produce identical outputs.
Impact:
Users cannot actually provide precomputed masked latents, and the callback tensor implies a state value that does not participate in generation.
Reproduction:
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImageInpaintPipeline, ZImageTransformer2DModel
torch.manual_seed(0)
transformer = ZImageTransformer2DModel(
all_patch_size=(2,), all_f_patch_size=(1,), in_channels=4, dim=16,
n_layers=1, n_refiner_layers=1, n_heads=2, n_kv_heads=2,
cap_feat_dim=8, axes_dims=[4, 2, 2], axes_lens=[64, 64, 64],
)
vae = AutoencoderKL(
in_channels=3, out_channels=3, down_block_types=["DownEncoderBlock2D"], up_block_types=["UpDecoderBlock2D"],
block_out_channels=[16], layers_per_block=1, latent_channels=4, norm_num_groups=4, sample_size=32,
scaling_factor=0.3611, shift_factor=0.1159,
)
pipe = ZImageInpaintPipeline(FlowMatchEulerDiscreteScheduler(), vae, None, None, transformer)
pipe.set_progress_bar_config(disable=True)
kwargs = dict(
prompt_embeds=[torch.randn(3, 8)], image=torch.rand(1, 3, 32, 32), mask_image=torch.ones(1, 1, 32, 32),
height=32, width=32, num_inference_steps=1, guidance_scale=0.0, output_type="latent",
latents=torch.randn(1, 4, 32, 32),
)
a = pipe(**kwargs, masked_image_latents=torch.zeros(1, 4, 32, 32), generator=torch.Generator().manual_seed(123)).images
b = pipe(**kwargs, masked_image_latents=torch.randn(1, 4, 32, 32), generator=torch.Generator().manual_seed(123)).images
print(torch.equal(a, b))
Relevant precedent:
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py, lines 1204-1209:

if masked_image_latents is None:
    masked_image = init_image * (mask_condition < 0.5)
else:
    masked_image = masked_image_latents

mask, masked_image_latents = self.prepare_mask_latents(
Suggested fix:
If Z-Image inpaint is intended to be latent-blending only, remove or deprecate masked_image_latents and the callback tensor. If it is intended to match SD-style inpaint conditioning, thread the provided tensor through prepare_mask_latents() and into the model input path.
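A minimal sketch of the SD-style option, patterned on the precedent above; it assumes prepare_mask_latents() can accept an already-encoded tensor in place of the image, which matches the SD inpaint helper but is not verified for Z-Image:
# Patch sketch for ZImageInpaintPipeline.__call__ (assumption: prepare_mask_latents()
# skips vae.encode() when handed latents, as pipeline_stable_diffusion_inpaint.py does).
if masked_image_latents is None:
    masked_image = init_image * (mask < 0.5)
else:
    masked_image = masked_image_latents  # thread the latents through instead of dropping them

mask, masked_image_latents = self.prepare_mask_latents(
    mask,
    masked_image,
    actual_batch_size,
    height,
    # ... remaining arguments unchanged from the current call site
)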
Issue 5: Model dtype rules are violated in shared transformer/controlnet helpers
Affected code:
src/diffusers/models/transformers/transformer_z_image.py, line 65 (and the identical line 67 of src/diffusers/models/controlnets/controlnet_z_image.py):

weight_dtype = self.mlp[0].weight.dtype

src/diffusers/models/transformers/transformer_z_image.py, lines 327-332 (and the identical lines 304-309 of controlnet_z_image.py):

def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0):
    with torch.device("cpu"):
        freqs_cis = []
        for i, (d, e) in enumerate(zip(dim, end)):
            freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
            timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
Problem:
TimestepEmbedder.forward() casts by reading self.mlp[0].weight.dtype, and RopeEmbedder.precompute_freqs_cis() unconditionally constructs float64 tensors. Both patterns are explicitly called out in the model review rules.
Impact:
This is fragile for quantized/GGUF/layerwise-casting loads and violates backend portability expectations for MPS/NPU-style environments.
Reproduction:
from pathlib import Path
for path in [
"src/diffusers/models/transformers/transformer_z_image.py",
"src/diffusers/models/controlnets/controlnet_z_image.py",
]:
for i, line in enumerate(Path(path).read_text().splitlines(), 1):
if "weight.dtype" in line or "torch.float64" in line:
print(path, i, line.strip())
Relevant precedent:
|
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 |
|
for i in range(n_axes): |
|
cos, sin = get_1d_rotary_pos_embed( |
|
self.axes_dim[i], |
|
pos[:, i], |
|
theta=self.theta, |
|
repeat_interleave_real=True, |
|
use_real=True, |
|
freqs_dtype=freqs_dtype, |
Suggested fix:
Use float32 for RoPE precompute unless there is measured need for gated float64, and pass the desired activation dtype from the caller into TimestepEmbedder instead of reading parameter storage dtype.
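A minimal sketch of both changes, assuming the flux-style backend gate quoted above; the reduced TimestepEmbedder is hypothetical and only illustrates passing the activation dtype in from the caller:
import torch
import torch.nn as nn

def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0) -> list[torch.Tensor]:
    # Simplified gate: fall back to float32 where float64 kernels are missing (e.g. MPS),
    # as transformer_flux.py does with its is_mps/is_npu check.
    freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
    freqs_cis = []
    for d, e in zip(dim, end):
        freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=freqs_dtype) / d))
        timestep = torch.arange(e, dtype=freqs_dtype)
        # Illustrative combination; the real helper builds rotary tables from these.
        freqs_cis.append(torch.outer(timestep, freqs))
    return freqs_cis

class TimestepEmbedder(nn.Module):
    # Hypothetical, reduced embedder: the point is the explicit `dtype` parameter in forward().
    def __init__(self, dim: int = 32, freq_dim: int = 16):
        super().__init__()
        self.freq_dim = freq_dim
        self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        half = self.freq_dim // 2
        freqs = torch.exp(-torch.arange(half, dtype=torch.float32, device=t.device) / half)
        emb = torch.cat([torch.cos(t[:, None] * freqs), torch.sin(t[:, None] * freqs)], dim=-1)
        # Cast to the caller-supplied activation dtype instead of reading
        # self.mlp[0].weight.dtype, which is wrong under quantized/GGUF/layerwise-cast weights.
        return self.mlp(emb.to(dtype))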
Issue 6: Modular pipeline generated docs still contain TODO placeholders
Affected code:
src/diffusers/modular_pipelines/z_image/modular_blocks_z_image.py, lines 46-75:

class ZImageCoreDenoiseStep(SequentialPipelineBlocks):
    """
    denoise block that takes encoded conditions and runs the denoising process.

    Components:
        transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
        (`ClassifierFreeGuidance`)

    Inputs:
        num_images_per_prompt (`None`, *optional*, defaults to 1):
            TODO: Add description.
        prompt_embeds (`list`):
            Pre-generated text embeddings. Can be generated from text_encoder step.
        negative_prompt_embeds (`list`, *optional*):
            Pre-generated negative text embeddings. Can be generated from text_encoder step.
        height (`int`, *optional*):
            TODO: Add description.
        width (`int`, *optional*):
            TODO: Add description.
        latents (`Tensor | NoneType`, *optional*):
            TODO: Add description.
        generator (`None`, *optional*):
            TODO: Add description.
        num_inference_steps (`None`, *optional*, defaults to 9):
            TODO: Add description.
        sigmas (`None`, *optional*):
            TODO: Add description.
        **denoiser_input_fields (`None`, *optional*):
            The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Problem:
modular_blocks_z_image.py contains many generated TODO: Add description. entries. The modular review rules explicitly require generated modular docstrings to be fixed after running auto-docstring generation.
Impact:
Public modular pipeline docs/API metadata are incomplete for several inputs, including height, width, latents, generator, sigmas, and workflow-specific inputs.
Reproduction:
from pathlib import Path
path = Path("src/diffusers/modular_pipelines/z_image/modular_blocks_z_image.py")
print(sum("TODO: Add description." in line for line in path.read_text().splitlines()))
Relevant precedent:
.ai/modular.md conversion checklist requires running utils/modular_auto_docstring.py --fix_and_overwrite and resolving TODO placeholders.
Suggested fix:
Add accurate InputParam descriptions/types for the missing fields, rerun python utils/modular_auto_docstring.py --fix_and_overwrite, and verify no generated TODOs remain.
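As a sketch of the first step, one of the missing descriptions could be filled in roughly like this (assumes the modular-pipelines InputParam helper; the wording and type hints are illustrative, not the required text):
# Illustrative only: give each block input a real description before regenerating docstrings.
from diffusers.modular_pipelines import InputParam

height_input = InputParam(
    name="height",
    type_hint=int,
    description="Height in pixels of the generated image. Must be divisible by the VAE/patch scale factor.",
)
sigmas_input = InputParam(
    name="sigmas",
    type_hint=list,
    description="Custom sigma schedule passed to the scheduler; overrides num_inference_steps when given.",
)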
Issue 7: Coverage gaps: no slow tests, no ControlNet/Omni pipeline tests, and docs omit public variants
Affected code:
src/diffusers/pipelines/z_image/__init__.py, lines 24-30:

_import_structure["pipeline_output"] = ["ZImagePipelineOutput"]
_import_structure["pipeline_z_image"] = ["ZImagePipeline"]
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
_import_structure["pipeline_z_image_inpaint"] = ["ZImageInpaintPipeline"]
_import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"]

tests/pipelines/z_image/test_z_image.py, line 45:

class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):

docs/source/en/api/pipelines/z_image.md, lines 91-105:

## ZImagePipeline

[[autodoc]] ZImagePipeline
- all
- __call__

## ZImageImg2ImgPipeline

[[autodoc]] ZImageImg2ImgPipeline
- all
- __call__

## ZImageInpaintPipeline

[[autodoc]] ZImageInpaintPipeline
Problem:
The package exports ZImageControlNetPipeline, ZImageControlNetInpaintPipeline, and ZImageOmniPipeline, but tests/pipelines/z_image/ only has fast tests for text2img/img2img/inpaint. There are no z_image slow tests. The pipeline docs only autodoc the three non-ControlNet/non-Omni pipelines.
Impact:
The exact Omni no-CFG crash and ControlNet checkpointing issue above are not covered by pipeline tests. Real-checkpoint regressions are also unguarded.
Reproduction:
from pathlib import Path
test_text = "\n".join(p.read_text() for p in Path("tests").rglob("*z_image*.py"))
docs = Path("docs/source/en/api/pipelines/z_image.md").read_text()
print("@slow" in test_text or "slow(" in test_text)
print("ZImageOmniPipeline" in test_text, "ZImageControlNetPipeline" in test_text)
print("ZImageOmniPipeline" in docs, "ZImageControlNetPipeline" in docs)
Relevant precedent:
Most mature pipeline families include at least one @slow real-checkpoint smoke test for public pipelines, plus fast tests for every exported variant.
Suggested fix:
Add fast tests for Omni and both ControlNet pipelines using tiny fixtures, add at least one slow real-checkpoint smoke test for the z_image family, and add autodoc sections for the public ControlNet and Omni pipelines.
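A skeleton for one of the missing fast tests, reusing the tiny configs from the reproductions above; the shipped version should follow ZImagePipelineFastTests and the PipelineTesterMixin contract (params, batch_params, get_dummy_inputs), which are omitted here:
import unittest

import torch

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    ZImageOmniPipeline,
    ZImageTransformer2DModel,
)


class ZImageOmniPipelineFastTests(unittest.TestCase):
    # Skeleton only: the real test should subclass PipelineTesterMixin like the
    # existing ZImagePipelineFastTests in tests/pipelines/z_image/test_z_image.py.
    def test_no_cfg_path(self):
        transformer = ZImageTransformer2DModel(
            all_patch_size=(2,), all_f_patch_size=(1,), in_channels=4, dim=16,
            n_layers=1, n_refiner_layers=1, n_heads=2, n_kv_heads=2,
            cap_feat_dim=8, axes_dims=[4, 2, 2], axes_lens=[32, 32, 32],
        )
        vae = AutoencoderKL(
            in_channels=3, out_channels=3, down_block_types=["DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D"], block_out_channels=[16], layers_per_block=1,
            latent_channels=4, norm_num_groups=4, sample_size=32,
            scaling_factor=0.3611, shift_factor=0.1159,
        )
        pipe = ZImageOmniPipeline(FlowMatchEulerDiscreteScheduler(), vae, None, None, transformer, None, None)
        # Would have caught Issue 1: the documented guidance_scale=0.0 usage.
        out = pipe(
            prompt_embeds=[[torch.randn(3, 8)]], height=32, width=32,
            num_inference_steps=1, guidance_scale=0.0, output_type="latent",
        )
        self.assertIsNotNone(out.images)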
Verification performed:
- Minimal .venv snippets confirmed Issues 1-4.
- tests/modular_pipelines/z_image/test_modular_pipeline_z_image.py -q: 14 passed.
- Pipeline fast tests could not be collected in this .venv because the installed PyTorch build lacks torch._C._distributed_c10d, imported via shared training test utilities.