[Pipelines] Refactor and optimize Z-Image modulev3 pipeline#20
[Pipelines] Refactor and optimize Z-Image modulev3 pipeline#20byungchul-sqzb wants to merge 1 commit into
Conversation
- Fix autoencoder import and image postprocessing - Absorb eager F.mul negate into compiled scheduler_step for z_image - Add batched CFG for Z-Image modulev3 pipeline - Optimize fused decode, scheduler caching, and eager reduction - Apply RoPE micro-optimizations stack-info: PR: #20, branch: byungchul-sqzb/stack/1
f76185d to
23b1c55
Compare
There was a problem hiding this comment.
Code Review
This pull request refactors the Z-Image pipeline to support batched Classifier-Free Guidance (CFG), optimizes the denoising loop, and introduces fused decoding via the _PostprocessAndDecodeKL module. It also updates the RoPE implementation to use interleaved frequencies and simplifies the attention output layer by removing unnecessary ModuleList wrapping. Feedback identifies a bug where a Buffer is sliced as a Tensor, suggests caching CFG timesteps to reduce host-to-device transfers, and recommends reverting manual slicing to F.chunk for better readability in the AdaLN modulation logic.
| ) | ||
| sigmas_host = np.asarray(model_inputs.sigmas, dtype=np.float32) | ||
| dt_values = np.ascontiguousarray(sigmas_host[1:] - sigmas_host[:-1]) | ||
| dts_seq = Buffer.from_dlpack(dt_values).to(device) |
There was a problem hiding this comment.
dts_seq is initialized as a Buffer, but it is later sliced and cast to a Tensor at line 1140. Buffer objects do not support slicing operations. It should be wrapped in a Tensor during initialization to allow proper indexing in the denoising loop.
| dts_seq = Buffer.from_dlpack(dt_values).to(device) | |
| dts_seq = Tensor( | |
| storage=Buffer.from_dlpack(dt_values).to(device) | |
| ) |
| cfg_timesteps = [ | ||
| Tensor( | ||
| storage=Buffer.from_dlpack( | ||
| np.full( | ||
| (2 * batch_size,), | ||
| float(transformed_timesteps[i]), | ||
| dtype=np.float32, | ||
| ) | ||
| ).to(device) | ||
| ) | ||
| for i in range(num_timesteps) | ||
| ] |
There was a problem hiding this comment.
| d = self.dim | ||
| scale_msa = 1.0 + mod[:, :, :d] | ||
| gate_msa = F.tanh(mod[:, :, d : 2 * d]) | ||
| scale_mlp = 1.0 + mod[:, :, 2 * d : 3 * d] | ||
| gate_mlp = F.tanh(mod[:, :, 3 * d :]) |
There was a problem hiding this comment.
Replacing F.chunk with manual slicing is less readable and potentially less efficient. F.chunk is the idiomatic way to split a modulation tensor into equal parts for AdaLN parameters, and it clearly expresses the intent to the compiler.
| d = self.dim | |
| scale_msa = 1.0 + mod[:, :, :d] | |
| gate_msa = F.tanh(mod[:, :, d : 2 * d]) | |
| scale_mlp = 1.0 + mod[:, :, 2 * d : 3 * d] | |
| gate_mlp = F.tanh(mod[:, :, 3 * d :]) | |
| scale_msa, gate_msa, scale_mlp, gate_mlp = F.chunk(mod, 4, axis=2) | |
| scale_msa = 1.0 + scale_msa | |
| gate_msa = F.tanh(gate_msa) | |
| scale_mlp = 1.0 + scale_mlp | |
| gate_mlp = F.tanh(gate_mlp) |
[Pipelines] Refactor and optimize Z-Image modulev3 pipeline