Skip to content

Commit f1e8eea

Browse files
committed
Apply head dimension [256, 20] surgery support to WAN 2.2
1 parent 71b4138 commit f1e8eea

4 files changed

Lines changed: 72 additions & 8 deletions

File tree

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,9 @@ enable_ssim: False
389389
enable_ml_diagnostics: False
390390
profiler_gcs_path: ""
391391
enable_ondemand_xprof: False
392+
393+
# Model surgery parameters
394+
override_model_dims: True
395+
target_head_dim: 256
396+
target_num_heads: 20
397+

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def get_1d_rotary_pos_embed(
226226
ntk_factor=1.0,
227227
freqs_dtype=jnp.float32,
228228
use_real: bool = True,
229+
original_dim: Optional[int] = None,
229230
):
230231
"""
231232
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -236,7 +237,8 @@ def get_1d_rotary_pos_embed(
236237
pos = jnp.arange(pos)
237238

238239
theta = theta * ntk_factor
239-
freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor
240+
scale_dim = original_dim if original_dim is not None else dim
241+
freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / scale_dim)) / linear_factor
240242
freqs = jnp.outer(pos, freqs)
241243
if use_real:
242244
# Flux

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,20 @@
4040
BlockSizes = common_types.BlockSizes
4141

4242

43-
def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int):
43+
def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int, original_attention_head_dim: int):
4444
h_dim = w_dim = 2 * (attention_head_dim // 6)
4545
t_dim = attention_head_dim - h_dim - w_dim
46+
current_dims = [t_dim, h_dim, w_dim]
47+
48+
h_dim_old = w_dim_old = 2 * (original_attention_head_dim // 6)
49+
t_dim_old = original_attention_head_dim - h_dim_old - w_dim_old
50+
old_dims = [t_dim_old, h_dim_old, w_dim_old]
51+
4652
freqs = []
47-
for dim in [t_dim, h_dim, w_dim]:
48-
freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float32, use_real=False)
53+
for dim, old_dim in zip(current_dims, old_dims):
54+
freq = get_1d_rotary_pos_embed(
55+
dim=dim, pos=max_seq_len, theta=theta, freqs_dtype=jnp.float32, use_real=False, original_dim=old_dim
56+
)
4957
freqs.append(freq)
5058
freqs = jnp.concatenate(freqs, axis=1)
5159
t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6)
@@ -62,8 +70,16 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int):
6270

6371
class WanRotaryPosEmbed(nnx.Module):
6472

65-
def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0):
73+
def __init__(
74+
self,
75+
attention_head_dim: int,
76+
original_attention_head_dim: int,
77+
patch_size: Tuple[int, int, int],
78+
max_seq_len: int,
79+
theta: float = 10000.0,
80+
):
6681
self.attention_head_dim = attention_head_dim
82+
self.original_attention_head_dim = original_attention_head_dim
6783
self.patch_size = patch_size
6884
self.max_seq_len = max_seq_len
6985
self.theta = theta
@@ -73,7 +89,7 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
7389
p_t, p_h, p_w = self.patch_size
7490
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
7591

76-
freqs_split = get_frequencies(self.max_seq_len, self.theta, self.attention_head_dim)
92+
freqs_split = get_frequencies(self.max_seq_len, self.theta, self.attention_head_dim, self.original_attention_head_dim)
7793

7894
freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
7995
freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))
@@ -494,15 +510,16 @@ def __init__(
494510
enable_jax_named_scopes: bool = False,
495511
use_base2_exp: bool = False,
496512
use_experimental_scheduler: bool = False,
513+
target_head_dim: int = 128,
497514
):
498-
inner_dim = num_attention_heads * attention_head_dim
515+
inner_dim = num_attention_heads * target_head_dim
499516
out_channels = out_channels or in_channels
500517
self.num_layers = num_layers
501518
self.scan_layers = scan_layers
502519
self.enable_jax_named_scopes = enable_jax_named_scopes
503520

504521
# 1. Patch & position embedding
505-
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
522+
self.rope = WanRotaryPosEmbed(target_head_dim, attention_head_dim, patch_size, rope_max_seq_len)
506523
self.patch_embedding = nnx.Conv(
507524
in_channels,
508525
inner_dim,

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,38 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl
9999
return vs
100100

101101

102+
from flax.traverse_util import flatten_dict, unflatten_dict
103+
104+
105+
def perform_wan_scaling_surgery(params, target_head_dim, source_head_dim):
106+
"""
107+
scales Q and K weights to preserve attention entropy when
108+
changing head dimensions.
109+
110+
Formula: correction_factor = (target_dim / source_dim)^0.25
111+
"""
112+
if target_head_dim == source_head_dim:
113+
print("Target and Source head dims are identical. Skipping surgery.")
114+
return params
115+
116+
ratio = target_head_dim / source_head_dim
117+
correction_factor = ratio**0.25
118+
119+
flat_params = flatten_dict(params, sep="/")
120+
new_flat_params = {}
121+
modified_count = 0
122+
123+
for key, tensor in flat_params.items():
124+
if ("query" in key or "key" in key) and "kernel" in key and "attn" in key:
125+
new_flat_params[key] = tensor * correction_factor
126+
modified_count += 1
127+
else:
128+
new_flat_params[key] = tensor
129+
130+
print(f"Surgery complete. Scaled {modified_count} tensors by {correction_factor:.4f}")
131+
return unflatten_dict(new_flat_params, sep="/")
132+
133+
102134
# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device.
103135
def create_sharded_logical_transformer(
104136
devices_array: np.array,
@@ -141,6 +173,11 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
141173
wan_config["use_base2_exp"] = config.use_base2_exp
142174
wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler
143175

176+
wan_config["target_head_dim"] = wan_config["attention_head_dim"]
177+
if getattr(config, "override_model_dims", False):
178+
wan_config["target_head_dim"] = config.target_head_dim
179+
wan_config["num_attention_heads"] = config.target_num_heads
180+
144181
# 2. eval_shape - will not use flops or create weights on device
145182
# thus not using HBM memory.
146183
p_model_factory = partial(create_model, wan_config=wan_config)
@@ -171,6 +208,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
171208
scan_layers=config.scan_layers,
172209
subfolder=subfolder,
173210
)
211+
if getattr(config, "override_model_dims", False):
212+
params = perform_wan_scaling_surgery(params, config.target_head_dim, wan_config["attention_head_dim"])
174213

175214
params = jax.tree_util.tree_map_with_path(
176215
lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params

0 commit comments

Comments
 (0)