4040BlockSizes = common_types .BlockSizes
4141
4242
43- def get_frequencies (max_seq_len : int , theta : int , attention_head_dim : int ):
43+ def get_frequencies (max_seq_len : int , theta : int , attention_head_dim : int , original_attention_head_dim : int ):
4444 h_dim = w_dim = 2 * (attention_head_dim // 6 )
4545 t_dim = attention_head_dim - h_dim - w_dim
46+ current_dims = [t_dim , h_dim , w_dim ]
47+
48+ h_dim_old = w_dim_old = 2 * (original_attention_head_dim // 6 )
49+ t_dim_old = original_attention_head_dim - h_dim_old - w_dim_old
50+ old_dims = [t_dim_old , h_dim_old , w_dim_old ]
51+
4652 freqs = []
47- for dim in [t_dim , h_dim , w_dim ]:
48- freq = get_1d_rotary_pos_embed (dim , max_seq_len , theta , freqs_dtype = jnp .float32 , use_real = False )
53+ for dim , old_dim in zip (current_dims , old_dims ):
54+ freq = get_1d_rotary_pos_embed (
55+ dim = dim , pos = max_seq_len , theta = theta , freqs_dtype = jnp .float32 , use_real = False , original_dim = old_dim
56+ )
4957 freqs .append (freq )
5058 freqs = jnp .concatenate (freqs , axis = 1 )
5159 t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6 )
@@ -62,8 +70,16 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int):
6270
6371class WanRotaryPosEmbed (nnx .Module ):
6472
65- def __init__ (self , attention_head_dim : int , patch_size : Tuple [int , int , int ], max_seq_len : int , theta : float = 10000.0 ):
73+ def __init__ (
74+ self ,
75+ attention_head_dim : int ,
76+ original_attention_head_dim : int ,
77+ patch_size : Tuple [int , int , int ],
78+ max_seq_len : int ,
79+ theta : float = 10000.0 ,
80+ ):
6681 self .attention_head_dim = attention_head_dim
82+ self .original_attention_head_dim = original_attention_head_dim
6783 self .patch_size = patch_size
6884 self .max_seq_len = max_seq_len
6985 self .theta = theta
@@ -73,7 +89,7 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
7389 p_t , p_h , p_w = self .patch_size
7490 ppf , pph , ppw = num_frames // p_t , height // p_h , width // p_w
7591
76- freqs_split = get_frequencies (self .max_seq_len , self .theta , self .attention_head_dim )
92+ freqs_split = get_frequencies (self .max_seq_len , self .theta , self .attention_head_dim , self . original_attention_head_dim )
7793
7894 freqs_f = jnp .expand_dims (jnp .expand_dims (freqs_split [0 ][:ppf ], axis = 1 ), axis = 1 )
7995 freqs_f = jnp .broadcast_to (freqs_f , (ppf , pph , ppw , freqs_split [0 ].shape [- 1 ]))
@@ -494,15 +510,16 @@ def __init__(
494510 enable_jax_named_scopes : bool = False ,
495511 use_base2_exp : bool = False ,
496512 use_experimental_scheduler : bool = False ,
513+ target_head_dim : int = 128 ,
497514 ):
498- inner_dim = num_attention_heads * attention_head_dim
515+ inner_dim = num_attention_heads * target_head_dim
499516 out_channels = out_channels or in_channels
500517 self .num_layers = num_layers
501518 self .scan_layers = scan_layers
502519 self .enable_jax_named_scopes = enable_jax_named_scopes
503520
504521 # 1. Patch & position embedding
505- self .rope = WanRotaryPosEmbed (attention_head_dim , patch_size , rope_max_seq_len )
522+ self .rope = WanRotaryPosEmbed (target_head_dim , attention_head_dim , patch_size , rope_max_seq_len )
506523 self .patch_embedding = nnx .Conv (
507524 in_channels ,
508525 inner_dim ,
0 commit comments