[Pipelines] Refactor and optimize Z-Image modulev3 pipeline #20

Open
byungchul-sqzb wants to merge 1 commit into main from byungchul-sqzb/stack/1

@byungchul-sqzb
Collaborator

[Pipelines] Refactor and optimize Z-Image modulev3 pipeline

  • Fix autoencoder import and image postprocessing
  • Absorb eager F.mul negate into compiled scheduler_step for z_image
  • Add batched CFG for Z-Image modulev3 pipeline
  • Optimize fused decode, scheduler caching, and eager reduction
  • Apply RoPE micro-optimizations

stack-info: PR: #20, branch: byungchul-sqzb/stack/1

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the Z-Image pipeline to support batched Classifier-Free Guidance (CFG), optimizes the denoising loop, and introduces fused decoding via the _PostprocessAndDecodeKL module. It also updates the RoPE implementation to use interleaved frequencies and simplifies the attention output layer by removing unnecessary ModuleList wrapping. Feedback identifies a bug where a Buffer is sliced as a Tensor, suggests caching CFG timesteps to reduce host-to-device transfers, and recommends reverting manual slicing to F.chunk for better readability in the AdaLN modulation logic.
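As a rough illustration of what "batched CFG" means in the summary above, the sketch below compares two model calls per step against one call on a doubled batch. NumPy stands in for the pipeline's tensor library, `toy_model`, `cfg_two_pass`, and `cfg_batched` are hypothetical names, and the guidance formula is the common textbook form, assumed here rather than taken from this PR.

```python
import numpy as np


def toy_model(latents, timesteps, cond):
    # Hypothetical stand-in for the transformer: each batch row is
    # processed independently, which is what makes batching valid.
    return latents * 0.5 + cond - timesteps[:, None]


def cfg_two_pass(latents, t, pos, neg, scale):
    # Baseline: one forward pass per prompt branch.
    ts = np.full((latents.shape[0],), t, dtype=np.float32)
    pred_pos = toy_model(latents, ts, pos)
    pred_neg = toy_model(latents, ts, neg)
    return pred_neg + scale * (pred_pos - pred_neg)


def cfg_batched(latents, t, pos, neg, scale):
    # Batched: stack both branches along the batch axis and run once.
    b = latents.shape[0]
    ts = np.full((2 * b,), t, dtype=np.float32)
    x = np.concatenate([latents, latents], axis=0)
    c = np.concatenate([pos, neg], axis=0)
    pred = toy_model(x, ts, c)
    pred_pos, pred_neg = pred[:b], pred[b:]
    return pred_neg + scale * (pred_pos - pred_neg)
```

The `(2 * batch_size,)` timestep shape in this sketch matches the shape used for `cfg_timesteps` in the diff reviewed below.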

)
sigmas_host = np.asarray(model_inputs.sigmas, dtype=np.float32)
dt_values = np.ascontiguousarray(sigmas_host[1:] - sigmas_host[:-1])
dts_seq = Buffer.from_dlpack(dt_values).to(device)

high

dts_seq is initialized as a Buffer, but it is later sliced and cast to a Tensor at line 1140. Buffer objects do not support slicing operations. It should be wrapped in a Tensor during initialization to allow proper indexing in the denoising loop.

Suggested change
- dts_seq = Buffer.from_dlpack(dt_values).to(device)
+ dts_seq = Tensor(
+     storage=Buffer.from_dlpack(dt_values).to(device)
+ )
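To make the indexing requirement concrete, here is a minimal host-side sketch of how the per-step deltas are built and then consumed one step at a time. It uses plain NumPy arrays only; the example sigma schedule and the `step_dt` helper are illustrative assumptions, and the library's `Buffer` and `Tensor` types are deliberately not used.

```python
import numpy as np

# Example schedule (assumed values, not taken from the PR).
sigmas_host = np.asarray([1.0, 0.75, 0.5, 0.25, 0.0], dtype=np.float32)

# Same expression as in the diff: dt[i] = sigma[i + 1] - sigma[i].
dt_values = np.ascontiguousarray(sigmas_host[1:] - sigmas_host[:-1])


def step_dt(i):
    # Hypothetical helper: the denoising loop needs the i-th delta,
    # so the container holding the deltas must support slicing.
    return dt_values[i : i + 1]
```

There is one delta per denoising step, so the sequence has one element fewer than the sigma schedule.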

Comment on lines +1071 to +1082
cfg_timesteps = [
    Tensor(
        storage=Buffer.from_dlpack(
            np.full(
                (2 * batch_size,),
                float(transformed_timesteps[i]),
                dtype=np.float32,
            )
        ).to(device)
    )
    for i in range(num_timesteps)
]

medium

The cfg_timesteps list is recreated on every execution, which involves num_timesteps host-to-device transfers. Since these values only depend on the timestep schedule and batch size, they should be cached (similar to timestep_scalars) to avoid redundant allocations and transfers.
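One possible shape for such a cache is sketched below, keyed on the timestep schedule and batch size. The names `_cfg_timestep_cache` and `get_cfg_timesteps` are hypothetical, and NumPy arrays stand in for device tensors; in the real pipeline the host-to-device transfer would sit in the cache-miss branch, and the device would presumably need to be part of the key as well.

```python
import numpy as np

# Hypothetical module-level cache: (schedule, batch_size) -> list of arrays.
_cfg_timestep_cache = {}


def get_cfg_timesteps(transformed_timesteps, batch_size):
    # Tuples of floats are hashable, so the schedule can serve as a key.
    schedule = tuple(float(t) for t in transformed_timesteps)
    key = (schedule, batch_size)
    cached = _cfg_timestep_cache.get(key)
    if cached is None:
        # Cache miss: build one (2 * batch_size,) array per timestep.
        cached = [
            np.full((2 * batch_size,), t, dtype=np.float32)
            for t in schedule
        ]
        _cfg_timestep_cache[key] = cached
    return cached
```

Repeated calls with the same schedule and batch size return the same list object, so the allocations happen once per distinct configuration.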

Comment on lines +98 to +102
d = self.dim
scale_msa = 1.0 + mod[:, :, :d]
gate_msa = F.tanh(mod[:, :, d : 2 * d])
scale_mlp = 1.0 + mod[:, :, 2 * d : 3 * d]
gate_mlp = F.tanh(mod[:, :, 3 * d :])

medium

Replacing F.chunk with manual slicing is less readable and potentially less efficient. F.chunk is the idiomatic way to split a modulation tensor into equal parts for AdaLN parameters, and it clearly expresses the intent to the compiler.

Suggested change
- d = self.dim
- scale_msa = 1.0 + mod[:, :, :d]
- gate_msa = F.tanh(mod[:, :, d : 2 * d])
- scale_mlp = 1.0 + mod[:, :, 2 * d : 3 * d]
- gate_mlp = F.tanh(mod[:, :, 3 * d :])
+ scale_msa, gate_msa, scale_mlp, gate_mlp = F.chunk(mod, 4, axis=2)
+ scale_msa = 1.0 + scale_msa
+ gate_msa = F.tanh(gate_msa)
+ scale_mlp = 1.0 + scale_mlp
+ gate_mlp = F.tanh(gate_mlp)
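The two forms should yield the same four pieces whenever the last axis of `mod` is exactly `4 * dim`. The sketch below checks this with NumPy, where `np.split` stands in for `F.chunk`; the shapes are illustrative assumptions, and it says nothing about how either form compiles or performs in the actual framework.

```python
import numpy as np

dim = 3  # assumed hidden size for the example
mod = np.arange(2 * 1 * 4 * dim, dtype=np.float32).reshape(2, 1, 4 * dim)

# Manual slicing, as in the PR.
sliced = (
    mod[:, :, :dim],
    mod[:, :, dim : 2 * dim],
    mod[:, :, 2 * dim : 3 * dim],
    mod[:, :, 3 * dim :],
)

# Equal four-way split along the last axis, as in the suggestion.
chunked = np.split(mod, 4, axis=2)

parts_match = all(np.array_equal(s, c) for s, c in zip(sliced, chunked))
```

If the last axis were ever not a multiple of four, `np.split` would raise an error, whereas the manual slices would silently produce a shorter final piece.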
