[MAX] Add Wan transformer model with block-level compilation#16
[MAX] Add Wan transformer model with block-level compilation#16jglee-sqbits wants to merge 1 commit into
Conversation
## Summary Add the Wan DiT (Diffusion Transformer) model with block-level compilation for memory-efficient inference. ## Description - Implements the full Wan transformer architecture: patch embedding, RoPE 3D positional encoding, self-attention, cross-attention, and adaptive LayerNorm - Uses **block-level compilation**: each of the 40 transformer blocks is compiled as a separate graph sharing the same compiled program, so only one block's activation workspace is live at a time - Block graphs use **symbolic seq_len** for resolution flexibility (480p/720p without recompilation) - Pre/post processing graphs use symbolic spatial dims - Supports diffusers weight key remapping and QKV fusion - Includes 3D RoPE computation matching the Wan frequency schedule ### Memory strategy The block-level approach keeps peak VRAM low (~6.5 GB per block execution vs ~18.5 GB for a single monolithic graph). This is critical for running the 14B parameter model at 720p (seq_len=75,600) on a single GPU. **Note:** `MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT=100` is required at 720p to avoid memory manager fragmentation with symbolic dims. ## Dependencies Should be merged **after** modular#6300 (VAE, for autoencoder restructuring). ## Checklist - [x] PR is small and focused - [x] I ran `./bazelw run format` to format my changes Assisted-by: Claude Code Assisted-by: Claude Code stack-info: PR: #16, branch: jglee-sqbits/stack/4
228c993 to
b387126
Compare
15eb852 to
3fb9e2d
Compare
There was a problem hiding this comment.
Code Review
This pull request implements the Wan transformer architecture, providing modules for embeddings, model configuration, and the core transformer blocks. The implementation includes a block-level execution model to minimize peak VRAM usage and custom normalization layers for improved numerical stability. Review feedback identifies opportunities to optimize the patch embedding process by removing redundant tensor permutations and reshapes. Additionally, suggestions were made to improve consistency in data type handling by storing the configured dtype as an instance attribute and using it instead of hardcoded values in the post-processing module.
| hs = ops.permute(hidden_states, [0, 2, 3, 4, 1]) | ||
| hs = self.patch_embedding(hs) | ||
| hs = ops.permute(hs, [0, 4, 1, 2, 3]) | ||
| seq_len = hs.shape[2] * hs.shape[3] * hs.shape[4] | ||
| hs = ops.reshape(hs, [batch_size, self.inner_dim, seq_len]) | ||
| hs = ops.permute(hs, [0, 2, 1]) |
There was a problem hiding this comment.
The permutations and reshapes here are redundant. Since ops.conv3d in MAX uses the NDHWC layout by default, the output of patch_embedding is already in [batch_size, ppf, pph, ppw, dim] format. This can be directly reshaped to [batch_size, seq_len, dim] without intermediate permutations.
| hs = ops.permute(hidden_states, [0, 2, 3, 4, 1]) | |
| hs = self.patch_embedding(hs) | |
| hs = ops.permute(hs, [0, 4, 1, 2, 3]) | |
| seq_len = hs.shape[2] * hs.shape[3] * hs.shape[4] | |
| hs = ops.reshape(hs, [batch_size, self.inner_dim, seq_len]) | |
| hs = ops.permute(hs, [0, 2, 1]) | |
| hs = ops.permute(hidden_states, [0, 2, 3, 4, 1]) | |
| hs = self.patch_embedding(hs) | |
| hs = ops.reshape(hs, [batch_size, -1, self.inner_dim]) |
| self.inner_dim = dim | ||
| self.out_channels = config.out_channels | ||
| self.patch_size = config.patch_size | ||
|
|
| hs, | ||
| [batch_size, self.out_channels, ppf * p_t, pph * p_h, ppw * p_w], | ||
| ) | ||
| return ops.cast(hs, DType.bfloat16) |
| hs = ops.permute(hidden_states, [0, 2, 3, 4, 1]) | ||
| hs = self.patch_embedding(hs) | ||
| hs = ops.permute(hs, [0, 4, 1, 2, 3]) | ||
| hs = ops.reshape(hs, [batch_size, self.inner_dim, ppf * pph * ppw]) | ||
| hs = ops.permute(hs, [0, 2, 1]) |
There was a problem hiding this comment.
Similar to the WanTransformerPreProcess module, the permutations and reshapes here are redundant. The output of patch_embedding can be directly reshaped to the target sequence length.
| hs = ops.permute(hidden_states, [0, 2, 3, 4, 1]) | |
| hs = self.patch_embedding(hs) | |
| hs = ops.permute(hs, [0, 4, 1, 2, 3]) | |
| hs = ops.reshape(hs, [batch_size, self.inner_dim, ppf * pph * ppw]) | |
| hs = ops.permute(hs, [0, 2, 1]) | |
| hs = ops.permute(hidden_states, [0, 2, 3, 4, 1]) | |
| hs = self.patch_embedding(hs) | |
| hs = ops.reshape(hs, [batch_size, -1, self.inner_dim]) |
Stacked PRs:
[MAX] Add Wan transformer model with block-level compilation
Summary
Add the Wan DiT (Diffusion Transformer) model with block-level compilation for memory-efficient inference.
Description
Memory strategy
The block-level approach keeps peak VRAM low (~6.5 GB per block execution vs ~18.5 GB for a single monolithic graph). This is critical for running the 14B parameter model at 720p (seq_len=75,600) on a single GPU.
Note:
MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT=100is required at 720p to avoid memory manager fragmentation with symbolic dims.Dependencies
Should be merged after modular#6300 (VAE, for autoencoder restructuring).
Checklist
./bazelw run formatto format my changesAssisted-by: Claude Code
Assisted-by: Claude Code