Skip to content

[Pipelines] Implement Z-Image ModuleV2 pipeline#21

Open
byungchul-sqzb wants to merge 1 commit into
byungchul-sqzb/stack/1from
byungchul-sqzb/stack/2
Open

[Pipelines] Implement Z-Image ModuleV2 pipeline#21
byungchul-sqzb wants to merge 1 commit into
byungchul-sqzb/stack/1from
byungchul-sqzb/stack/2

Conversation

@byungchul-sqzb
Copy link
Copy Markdown
Collaborator

[Pipelines] Implement Z-Image ModuleV2 pipeline

Port Z-Image to the Graph API / ModuleV2 runtime using V2 text encoder,
transformer, and VAE components.

Restore the current ModuleV3 feature set and behavior in the V2 path,
including:

  • Z-Image transformer/model/config/weight adapter wiring
  • diffusion pipeline, arch registration, and ModuleV2/ModuleV3 selection via
    --prefer-module-v3
  • batched CFG, CFG renormalization, and image parity with ModuleV3
  • transformer-side RoPE micro-optimizations:
    • single unified RoPE embedder call
    • interleaved [cos, sin] frequency generation
    • rope_ragged_with_position_ids hot path
    • preamble dtype cast and direct modulation slicing

Port Z-Image to the Graph API / ModuleV2 runtime using V2 text encoder,
transformer, and VAE components.

Restore the current ModuleV3 feature set and behavior in the V2 path,
including:
- Z-Image transformer/model/config/weight adapter wiring
- diffusion pipeline, arch registration, and ModuleV2/ModuleV3 selection via
  --prefer-module-v3
- batched CFG, CFG renormalization, and image parity with ModuleV3
- transformer-side RoPE micro-optimizations:
  - single unified RoPE embedder call
  - interleaved [cos, sin] frequency generation
  - rope_ragged_with_position_ids hot path
  - preamble dtype cast and direct modulation slicing

stack-info: PR: #21, branch: byungchul-sqzb/stack/2
Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements the Z-Image diffusion architecture, providing both standard and ModuleV3 versions. The changes include the core DiT model, attention mechanisms with rotary embeddings, and the generation pipeline. Review feedback identifies a critical shape mismatch in the attention layer's position IDs, a logic error in the CFG renormalization process where the target norm is incorrectly calculated, and a performance bottleneck caused by frequent host-to-device buffer transfers within the execution loop.

Comment on lines +41 to +46
position_ids = ops.range(
0, seq_len, dtype=DType.uint32, device=query.device
)
position_ids = ops.broadcast_to(
ops.unsqueeze(position_ids, 0), [batch_size, seq_len]
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The position_ids tensor must be flattened to 1D to match the first dimension of the ragged input tensors (query_ragged, key_ragged) passed to the rope_ragged_with_position_ids kernel. Currently, it is a 2D tensor of shape [batch_size, seq_len], which will cause a shape mismatch or incorrect indexing in the kernel.

Suggested change
position_ids = ops.range(
0, seq_len, dtype=DType.uint32, device=query.device
)
position_ids = ops.broadcast_to(
ops.unsqueeze(position_ids, 0), [batch_size, seq_len]
)
position_ids = ops.range(
0, seq_len, dtype=DType.uint32, device=query.device
)
position_ids = ops.broadcast_to(
ops.unsqueeze(position_ids, 0), [batch_size, seq_len]
)
position_ids = ops.reshape(position_ids, [batch_size * seq_len])

Comment on lines +1017 to +1043
with Tracer("transformer"):
noise_pred = self.transformer(
latents,
prompt_embeds,
timestep,
img_ids,
txt_ids,
)[0]
assert negative_prompt_embeds is not None
with Tracer("cfg_transformer"):
neg_noise_pred = self.transformer(
latents,
negative_prompt_embeds,
timestep,
neg_img_ids,
neg_txt_ids,
)[0]
assert guidance_buf is not None
noise_pred = self._cfg_combine(
noise_pred, neg_noise_pred, guidance_buf
)
if model_inputs.cfg_normalization:
noise_pred = self._cfg_renormalization(
noise_pred,
noise_pred,
)
else:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In the non-batched CFG path (used when explicit_negative_prompt is True), the renormalization logic is incorrect. It passes the CFG result (noise_pred) as both the pos and pred arguments to _cfg_renormalization. This makes the renormalization a no-op because it uses the norm of the CFG result as the target norm. It should instead use the norm of the unconditioned (positive) prediction.

                    elif apply_cfg:
                        with Tracer("transformer"):
                            pos_noise_pred = self.transformer(
                                latents,
                                prompt_embeds,
                                timestep,
                                img_ids,
                                txt_ids,
                            )[0]
                        assert negative_prompt_embeds is not None
                        with Tracer("cfg_transformer"):
                            neg_noise_pred = self.transformer(
                                latents,
                                negative_prompt_embeds,
                                timestep,
                                neg_img_ids,
                                neg_txt_ids,
                            )[0]
                        assert guidance_buf is not None
                        noise_pred = self._cfg_combine(
                            pos_noise_pred, neg_noise_pred, guidance_buf
                        )
                        if model_inputs.cfg_normalization:
                            noise_pred = self._cfg_renormalization(
                                pos_noise_pred,
                                noise_pred,
                            )

Comment on lines +981 to +986
cfg_timestep_bufs = [
Buffer.from_dlpack(
np.full((2 * batch_size,), float(t), dtype=np.float32)
).to(device)
for t in transformed
]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Creating a list of buffers on the host and uploading them to the device inside the execute method is inefficient, as it incurs host-side allocation and device transfer overhead for every denoising step. These buffers should be pre-allocated and uploaded in prepare_inputs, or ideally, generated on-device using graph operations (e.g., slicing and broadcasting the existing all_timesteps buffer).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant