Skip to content

Commit 3fb9e2d

Browse files
committed
[MAX] Add Wan transformer model with block-level compilation
## Summary

Add the Wan DiT (Diffusion Transformer) model with block-level compilation for memory-efficient inference.

## Description

- Implements the full Wan transformer architecture: patch embedding, RoPE 3D positional encoding, self-attention, cross-attention, and adaptive LayerNorm
- Uses **block-level compilation**: each of the 40 transformer blocks is compiled as a separate graph sharing the same compiled program, so only one block's activation workspace is live at a time
- Block graphs use **symbolic seq_len** for resolution flexibility (480p/720p without recompilation)
- Pre/post processing graphs use symbolic spatial dims
- Supports diffusers weight key remapping and QKV fusion
- Includes 3D RoPE computation matching the Wan frequency schedule

### Memory strategy

The block-level approach keeps peak VRAM low (~6.5 GB per block execution vs ~18.5 GB for a single monolithic graph). This is critical for running the 14B parameter model at 720p (seq_len=75,600) on a single GPU.

**Note:** `MODULAR_DEVICE_CONTEXT_MEMORY_MANAGER_CHUNK_PERCENT=100` is required at 720p to avoid memory manager fragmentation with symbolic dims.

## Dependencies

Should be merged **after** modular#6300 (VAE, for autoencoder restructuring).

## Checklist

- [x] PR is small and focused
- [x] I ran `./bazelw run format` to format my changes

Assisted-by: Claude Code
Assisted-by: Claude Code
stack-info: PR: #16, branch: jglee-sqbits/stack/4
1 parent b387126 commit 3fb9e2d

4 files changed

Lines changed: 1743 additions & 0 deletions

File tree

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from __future__ import annotations
15+
16+
import math
17+
from collections.abc import Callable
18+
from typing import Any
19+
20+
from max.dtype import DType
21+
from max.graph import DeviceRef, TensorValue, ops
22+
from max.nn.activation import activation_function_from_name
23+
from max.nn.layer import Module
24+
from max.nn.linear import Linear
25+
26+
27+
def get_timestep_embedding(
    timesteps: TensorValue,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> TensorValue:
    """Build sinusoidal timestep embeddings as graph ops.

    For each frequency index ``i`` in ``[0, embedding_dim // 2)`` the
    frequency is
    ``exp(-log(max_period) * i / (embedding_dim // 2 - downscale_freq_shift))``.
    Every timestep is multiplied by ``scale`` and by each frequency, and the
    sine and cosine of the products are concatenated along the last axis.

    Args:
        timesteps: Timestep values. It is unsqueezed on axis 1 before being
            multiplied with the frequencies, so it is expected to be rank 1
            (one value per batch element) -- NOTE(review): not enforced here.
        embedding_dim: Width of the produced embedding. When odd, one
            zero-padded column is appended after the sin/cos halves.
        flip_sin_to_cos: If True the layout is ``[cos, sin]`` instead of
            ``[sin, cos]``.
        downscale_freq_shift: Subtracted from ``embedding_dim // 2`` in the
            frequency denominator.
        scale: Multiplier applied to the timesteps before the sin/cos.
        max_period: Controls the lowest frequency of the embedding.

    Returns:
        A float32 tensor holding the embedding for every timestep.

    Raises:
        ValueError: If ``embedding_dim // 2 == downscale_freq_shift``, which
            would divide the exponents by zero and yield NaN/inf embeddings
            (e.g. ``embedding_dim`` of 2 or 3 with the default shift of 1).
    """
    half_dim = embedding_dim // 2
    freq_divisor = half_dim - downscale_freq_shift
    if freq_divisor == 0:
        # Fail at graph-construction time instead of silently compiling a
        # graph whose output is NaN.
        raise ValueError(
            f"`embedding_dim // 2` ({half_dim}) must differ from "
            f"`downscale_freq_shift` ({downscale_freq_shift}); the frequency "
            "exponent would otherwise be divided by zero."
        )

    exponent = -math.log(max_period) * ops.range(
        0, half_dim, dtype=DType.float32, device=timesteps.device
    )
    exponent = exponent / freq_divisor
    freqs = ops.exp(exponent)

    # Outer product of timesteps and frequencies via broadcasting.
    timesteps_expanded = ops.cast(ops.unsqueeze(timesteps, 1), DType.float32)
    freqs_expanded = ops.unsqueeze(freqs, 0)
    angles = scale * timesteps_expanded * freqs_expanded

    sin_part = ops.sin(angles)
    cos_part = ops.cos(angles)
    # Emit the requested layout directly rather than concatenating, slicing
    # the halves back out and concatenating again.
    if flip_sin_to_cos:
        emb = ops.concat([cos_part, sin_part], axis=-1)
    else:
        emb = ops.concat([sin_part, cos_part], axis=-1)

    if embedding_dim % 2 == 1:
        # Zero-pad one trailing column so the width equals `embedding_dim`.
        # NOTE(review): assumes `ops.pad` takes (before, after) pairs per
        # axis starting at axis 0 -- confirm against the MAX ops reference.
        emb = ops.pad(emb, (0, 0, 0, 1))
    return emb
50+
51+
52+
def apply_rotary_emb(
    x: TensorValue,
    freqs_cis: tuple[TensorValue, TensorValue],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> TensorValue:
    """Rotate `x` with precomputed rotary (cos, sin) tables.

    The result is ``x * cos + rotate(x) * sin`` evaluated in float32 and cast
    back to the dtype of `x`.

    Args:
        x: Tensor to rotate. Its last axis is split into pairs.
            NOTE(review): presumably rank 4 (batch/heads/sequence/head_dim
            in some order given by `sequence_dim`) -- confirm with callers.
        freqs_cis: ``(cos, sin)`` tables. Two singleton axes are inserted so
            they broadcast against `x`.
        use_real: Must be True; the complex formulation is not implemented.
        use_real_unbind_dim: ``-1`` treats the last axis as interleaved
            (real, imag) pairs; ``-2`` treats it as a real half followed by
            an imaginary half.
        sequence_dim: Axis of `x` that holds the sequence (1 or 2); decides
            where the singleton axes go on the tables.

    Returns:
        The rotated tensor with the same dtype as `x`.

    Raises:
        NotImplementedError: If `use_real` is False.
        ValueError: If `sequence_dim` or `use_real_unbind_dim` has an
            unsupported value.
    """
    if not use_real:
        raise NotImplementedError("Only use_real=True is supported")

    cos, sin = freqs_cis
    if sequence_dim not in (1, 2):
        raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

    # Both layouts first prepend a leading axis; the second singleton goes in
    # front again (sequence_dim == 2) or right after the sequence axis
    # (sequence_dim == 1).
    second_axis = 0 if sequence_dim == 2 else 2
    cos = ops.unsqueeze(ops.unsqueeze(cos, 0), second_axis)
    sin = ops.unsqueeze(ops.unsqueeze(sin, 0), second_axis)

    input_dtype = x.dtype

    if use_real_unbind_dim == -1:
        # Interleaved layout: [..., D] -> [..., D/2, 2], swap each pair to
        # (-imag, real), then flatten back to the original shape.
        full_shape: list[Any] = list(x.shape)
        paired_shape: list[Any] = full_shape[:-1] + [full_shape[-1] // 2, 2]
        pairs = ops.reshape(x, paired_shape)
        real_part = pairs[..., 0]
        imag_part = pairs[..., 1]
        swapped = ops.stack([-imag_part, real_part], axis=-1)
        x_rotated = ops.reshape(swapped, full_shape)
    elif use_real_unbind_dim == -2:
        # Split-half layout: [..., D] -> [..., 2, D/2], then (-imag | real).
        full_shape = list(x.shape)
        halved_shape = full_shape[:-1] + [2, full_shape[-1] // 2]
        halves = ops.reshape(x, halved_shape)
        real_part = halves[..., 0, :]
        imag_part = halves[..., 1, :]
        x_rotated = ops.concat([-imag_part, real_part], axis=-1)
    else:
        raise ValueError(
            f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2."
        )

    # Do the arithmetic in float32 regardless of the input dtype.
    direct_term = ops.cast(x, DType.float32) * ops.cast(cos, DType.float32)
    cross_term = ops.cast(x_rotated, DType.float32) * ops.cast(
        sin, DType.float32
    )
    return ops.cast(direct_term + cross_term, input_dtype)
98+
99+
100+
class Timesteps(Module):
    """Layer form of `get_timestep_embedding` with the options fixed up front.

    The constructor only records the embedding configuration; calling the
    layer forwards the timesteps and that configuration to
    `get_timestep_embedding`.
    """

    def __init__(
        self,
        num_channels: int,
        flip_sin_to_cos: bool,
        downscale_freq_shift: float,
        scale: float = 1,
    ):
        """Record the embedding configuration.

        Args:
            num_channels: Passed as the embedding width.
            flip_sin_to_cos: Forwarded unchanged on every call.
            downscale_freq_shift: Forwarded unchanged on every call.
            scale: Forwarded unchanged on every call.
        """
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def __call__(self, timesteps: TensorValue) -> TensorValue:
        """Return the embedding of `timesteps` using the stored options."""
        embedding = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return embedding
122+
123+
124+
class TimestepEmbedding(Module):
    """Two-layer MLP applied to a timestep feature vector.

    Computes ``linear_2(act(linear_1(sample)))`` followed by an optional
    post-activation. When the layer is built with `cond_proj_dim`, a
    bias-free projection maps an extra conditioning tensor to `in_channels`
    features, which are added to `sample` before the first linear layer.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int | None = None,
        post_act_fn: str | None = None,
        cond_proj_dim: int | None = None,
        sample_proj_bias: bool = True,
        *,
        dtype: DType = DType.bfloat16,
        device: DeviceRef = DeviceRef.CPU(),
    ):
        """Create the projection layers.

        Args:
            in_channels: Feature width of the incoming `sample`.
            time_embed_dim: Hidden width (output of the first linear layer).
            act_fn: Name of the activation between the two linear layers.
            out_dim: Output width; defaults to `time_embed_dim` when None.
            post_act_fn: Optional name of an activation applied to the output.
            cond_proj_dim: Feature width of the optional conditioning tensor.
                When None no conditioning projection is created.
            sample_proj_bias: Whether the two linear layers have a bias.
            dtype: Dtype of the layer weights.
            device: Device on which the weights are placed.
        """
        super().__init__()
        self.linear_1 = Linear(
            in_dim=in_channels,
            out_dim=time_embed_dim,
            dtype=dtype,
            device=device,
            has_bias=sample_proj_bias,
        )

        # Maps the conditioning tensor to `in_channels` so that it can be
        # added to `sample`.
        self.cond_proj: Linear | None
        if cond_proj_dim is not None:
            self.cond_proj = Linear(
                in_dim=cond_proj_dim,
                out_dim=in_channels,
                dtype=dtype,
                device=device,
                has_bias=False,
            )
        else:
            self.cond_proj = None

        self.act = activation_function_from_name(act_fn)

        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = Linear(
            in_dim=time_embed_dim,
            out_dim=time_embed_dim_out,
            dtype=dtype,
            device=device,
            has_bias=sample_proj_bias,
        )

        self.post_act: Callable[[TensorValue], TensorValue] | None
        if post_act_fn is not None:
            self.post_act = activation_function_from_name(post_act_fn)
        else:
            self.post_act = None

    def __call__(
        self, sample: TensorValue, condition: TensorValue | None = None
    ) -> TensorValue:
        """Project `sample` through the MLP.

        Args:
            sample: Timestep features whose last axis has `in_channels`
                entries.
            condition: Optional conditioning features whose last axis has
                `cond_proj_dim` entries. When given, its projection is added
                to `sample` before the first linear layer. When omitted, no
                conditioning is applied even if the projection exists.

        Returns:
            The embedded tensor produced by the second linear layer (after
            the optional post-activation).

        Raises:
            ValueError: If `condition` is given but the layer was created
                without `cond_proj_dim`.
        """
        if condition is not None:
            if self.cond_proj is None:
                raise ValueError(
                    "`condition` was provided but this TimestepEmbedding was "
                    "created without `cond_proj_dim`."
                )
            # `cond_proj` expects `cond_proj_dim` input features, so it must
            # be fed the conditioning tensor rather than `sample` itself.
            sample = sample + self.cond_proj(condition)

        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample

0 commit comments

Comments
 (0)