
[MAX] Add Wan transformer model with block-level compilation #16

Draft: jglee-sqbits wants to merge 1 commit into jglee-sqbits/stack/3 from jglee-sqbits/stack/4

jglee-sqbits commented Apr 1, 2026

Stacked PRs:


## Summary

Add the Wan DiT (Diffusion Transformer) model with block-level compilation for memory-efficient inference.

## Description

- Implements the full Wan transformer architecture: patch embedding, RoPE 3D positional encoding, self-attention, cross-attention, and adaptive LayerNorm
- Uses **block-level compilation**: each of the 40 transformer blocks is compiled as a separate graph sharing the same compiled program, so only one block's activation workspace is live at a time
- Block graphs use **symbolic seq_len** for resolution flexibility (480p/720p without recompilation)
- Pre/post processing graphs use symbolic spatial dims
- Supports diffusers weight key remapping and QKV fusion
- Includes 3D RoPE computation matching the Wan frequency schedule
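
The block-level execution pattern in the list above can be illustrated in plain Python. This is a minimal sketch, not the MAX graph API: `compile_block`, `BlockWeights`, and `run_blocks` are hypothetical names, and the toy block body stands in for the real attention and feed-forward layers. It only shows the shape of the idea: one compiled program, 40 weight sets, and an input length that is not fixed at compile time.

```python
from dataclasses import dataclass
from typing import Callable, List

NUM_BLOCKS = 40  # number of transformer blocks, from the PR description


@dataclass
class BlockWeights:
    """Hypothetical per-block parameters (stand-in for real block weights)."""
    scale: float
    shift: float


BlockFn = Callable[[List[float], BlockWeights], List[float]]

compile_count = 0  # counts how many times the "compiler" runs


def compile_block() -> BlockFn:
    """Stand-in for compiling one block graph with a symbolic seq_len."""
    global compile_count
    compile_count += 1

    def run(hidden: List[float], weights: BlockWeights) -> List[float]:
        # Works for any len(hidden): the length plays the role of seq_len.
        return [h * weights.scale + weights.shift for h in hidden]

    return run


def run_blocks(block: BlockFn, hidden: List[float],
               all_weights: List[BlockWeights]) -> List[float]:
    """Run every block in sequence through the same compiled program."""
    for weights in all_weights:
        # Rebinding `hidden` drops the previous block's output, so only one
        # block's activations are referenced at a time.
        hidden = block(hidden, weights)
    return hidden


block = compile_block()  # compiled once, shared by all 40 blocks
weights = [BlockWeights(scale=1.0, shift=0.5) for _ in range(NUM_BLOCKS)]

out_short = run_blocks(block, [0.0] * 4, weights)  # a short sequence
out_long = run_blocks(block, [0.0] * 9, weights)   # a longer one, no recompile
```

After both calls `compile_count` is still 1, which mirrors the claim that a resolution change does not trigger recompilation.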

### Memory strategy
The block-level approach keeps peak VRAM low (~6.5 GB per block execution vs ~18.5 GB for a single monolithic graph). This is critical for running the 14B parameter model at 720p (seq_len=75,600) on a single GPU.
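
The 720p sequence length can be reproduced with a short calculation. The sketch below assumes an 81-frame clip at 1280x720, a VAE that compresses 8x spatially and 4x temporally, and a patch size of (1, 2, 2). None of these values are stated in this PR; they are assumptions chosen because they reproduce seq_len=75,600, and `wan_seq_len` is a hypothetical helper.

```python
def wan_seq_len(num_frames: int, height: int, width: int,
                vae_temporal: int = 4, vae_spatial: int = 8,
                patch: tuple = (1, 2, 2)) -> int:
    """Number of transformer tokens for one video (assumed Wan-style layout)."""
    # Assumed VAE behaviour: first frame kept, the rest compressed 4x in time.
    latent_frames = (num_frames - 1) // vae_temporal + 1
    latent_h = height // vae_spatial
    latent_w = width // vae_spatial
    p_t, p_h, p_w = patch
    # One token per (p_t x p_h x p_w) patch of the latent volume.
    return (latent_frames // p_t) * (latent_h // p_h) * (latent_w // p_w)


seq_720p = wan_seq_len(81, 720, 1280)  # 21 * 45 * 80
seq_480p = wan_seq_len(81, 480, 832)   # 21 * 30 * 52
```

Under these assumptions the 720p sequence is roughly 2.3 times longer than the 480p one, which is why activation memory matters most at the higher resolution.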

**Note:** `MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT=100` is required at 720p to avoid memory manager fragmentation with symbolic dims.
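
One way to set the variable from Python is shown below. This is a sketch that assumes the variable is read when the MAX device context is first created; the PR does not say when it is read, so exporting it in the shell before launching the process is the safer choice if in doubt.

```python
import os

# Assumption: this must run before any MAX module creates a device context.
CHUNK_VAR = "MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT"
os.environ[CHUNK_VAR] = "100"
```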

## Dependencies

Should be merged **after** modular#6300 (VAE, for autoencoder restructuring).

## Checklist

- [x] PR is small and focused
- [x] I ran `./bazelw run format` to format my changes

Assisted-by: Claude Code


stack-info: PR: #16, branch: jglee-sqbits/stack/4

gemini-code-assist (bot) left a comment


Code Review

This pull request implements the Wan transformer architecture, providing modules for embeddings, model configuration, and the core transformer blocks. The implementation includes a block-level execution model to minimize peak VRAM usage and custom normalization layers for improved numerical stability. Review feedback identifies opportunities to optimize the patch embedding process by removing redundant tensor permutations and reshapes. Additionally, suggestions were made to improve consistency in data type handling by storing the configured dtype as an instance attribute and using it instead of hardcoded values in the post-processing module.

Comment on lines +685 to +690
hs = ops.permute(hidden_states, [0, 2, 3, 4, 1])
hs = self.patch_embedding(hs)
hs = ops.permute(hs, [0, 4, 1, 2, 3])
seq_len = hs.shape[2] * hs.shape[3] * hs.shape[4]
hs = ops.reshape(hs, [batch_size, self.inner_dim, seq_len])
hs = ops.permute(hs, [0, 2, 1])

medium

The permute/reshape/permute sequence after patch_embedding is redundant. Since ops.conv3d in MAX uses the NDHWC layout by default, the output of patch_embedding is already in [batch_size, ppf, pph, ppw, dim] format and can be reshaped directly to [batch_size, seq_len, dim] without intermediate permutations.

Suggested change
- hs = ops.permute(hidden_states, [0, 2, 3, 4, 1])
- hs = self.patch_embedding(hs)
- hs = ops.permute(hs, [0, 4, 1, 2, 3])
- seq_len = hs.shape[2] * hs.shape[3] * hs.shape[4]
- hs = ops.reshape(hs, [batch_size, self.inner_dim, seq_len])
- hs = ops.permute(hs, [0, 2, 1])
+ hs = ops.permute(hidden_states, [0, 2, 3, 4, 1])
+ hs = self.patch_embedding(hs)
+ hs = ops.reshape(hs, [batch_size, -1, self.inner_dim])
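
The equivalence claimed in this comment can be checked without MAX. The sketch below uses hypothetical `permute` and `reshape` helpers on flat row-major lists (they are not MAX ops), and it assumes, as the comment states, that patch_embedding returns NDHWC. It compares the original permute/reshape/permute path with the single direct reshape on a small tensor.

```python
from itertools import product
from math import prod


def reshape(flat, shape, new_shape):
    """Row-major reshape: element order is unchanged, only the shape differs."""
    assert prod(shape) == prod(new_shape)
    return flat, list(new_shape)


def permute(flat, shape, perm):
    """Reorder axes of a row-major tensor stored as a flat list."""
    new_shape = [shape[p] for p in perm]
    # Row-major strides of the source tensor.
    strides = [prod(shape[i + 1:]) for i in range(len(shape))]
    out = []
    for idx in product(*(range(n) for n in new_shape)):
        # Output axis k reads from source axis perm[k].
        src = sum(idx[k] * strides[perm[k]] for k in range(len(perm)))
        out.append(flat[src])
    return out, new_shape


B, F, H, W, C = 2, 3, 2, 2, 4      # batch, ppf, pph, ppw, inner_dim
ndhwc_shape = [B, F, H, W, C]
ndhwc = list(range(prod(ndhwc_shape)))  # stand-in for the conv3d output

# Original path: NDHWC -> NCDHW -> [B, C, S] -> [B, S, C]
x, s = permute(ndhwc, ndhwc_shape, [0, 4, 1, 2, 3])
x, s = reshape(x, s, [B, C, F * H * W])
original = permute(x, s, [0, 2, 1])

# Suggested path: one reshape straight from NDHWC to [B, S, C]
suggested = reshape(ndhwc, ndhwc_shape, [B, F * H * W, C])

paths_match = original == suggested
```

Both paths yield the same data and the same [B, S, C] shape here, so the suggestion changes only the number of ops, not the result. Whether `-1` is accepted by ops.reshape alongside symbolic dims is not verified by this sketch.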

self.inner_dim = dim
self.out_channels = config.out_channels
self.patch_size = config.patch_size


medium

The dtype parameter passed to __init__ should be stored as an instance attribute to ensure consistency in the __call__ method, especially for the final output cast.

Suggested change
self.dtype = dtype

hs,
[batch_size, self.out_channels, ppf * p_t, pph * p_h, ppw * p_w],
)
return ops.cast(hs, DType.bfloat16)

medium

The output cast is hardcoded to DType.bfloat16. It should use the dtype provided during initialization to support other precisions (e.g., float32) if configured.

Suggested change
- return ops.cast(hs, DType.bfloat16)
+ return ops.cast(hs, self.dtype)

Comment on lines +859 to +863
hs = ops.permute(hidden_states, [0, 2, 3, 4, 1])
hs = self.patch_embedding(hs)
hs = ops.permute(hs, [0, 4, 1, 2, 3])
hs = ops.reshape(hs, [batch_size, self.inner_dim, ppf * pph * ppw])
hs = ops.permute(hs, [0, 2, 1])

medium

Similar to the WanTransformerPreProcess module, the permute/reshape/permute sequence after patch_embedding is redundant here. The output of patch_embedding can be reshaped directly to [batch_size, seq_len, dim].

Suggested change
- hs = ops.permute(hidden_states, [0, 2, 3, 4, 1])
- hs = self.patch_embedding(hs)
- hs = ops.permute(hs, [0, 4, 1, 2, 3])
- hs = ops.reshape(hs, [batch_size, self.inner_dim, ppf * pph * ppw])
- hs = ops.permute(hs, [0, 2, 1])
+ hs = ops.permute(hidden_states, [0, 2, 3, 4, 1])
+ hs = self.patch_embedding(hs)
+ hs = ops.reshape(hs, [batch_size, -1, self.inner_dim])
