@@ -28,10 +28,8 @@ class Qwen35MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter):
2828 NUM_SLOTS = 4 # 3 GDN layers + 1 Full Attention layer per cycle
2929
3030 def convert (self , model_state : dict ):
31- """Main entry point for the Tunix weight synchronization."""
3231 logging .info ("\n %sStarting Qwen 3.5 Conversion (Hybrid 3:1 MoE)...%s" , GREEN , RESET )
3332 self .vllm_state = {}
34-
3533 self .num_reps = self .num_layers // self .NUM_SLOTS
3634
3735 with timer ("Convert Global Weights" ):
@@ -43,45 +41,27 @@ def convert(self, model_state: dict):
4341 with timer ("Convert MoE Weights" ):
4442 self ._convert_moe (model_state )
4543
46- # ------------------------------------------------------------------ #
47- # Protect JAX compilation
48- # ------------------------------------------------------------------ #
44+ # Protect JAX compilation by enforcing bfloat16
4945 for key in self .vllm_state :
5046 self .vllm_state [key ] = self .vllm_state [key ].astype (jnp .bfloat16 )
5147
5248 return self .vllm_state
5349
54- # ------------------------------------------------------------------ #
55- # 1. Global Weights
56- # ------------------------------------------------------------------ #
5750 def _convert_global (self , params ):
58- logging .info ("_convert_global: Processing embeddings and LM head..." )
59-
6051 self .vllm_state ["vllm_model.language_model.model.embed_tokens.weight" ] = jnp .array (
6152 params ["base" ]["token_embedder" ]["embedding" ]
6253 )
63-
6454 self .vllm_state ["vllm_model.language_model.model.norm.weight" ] = jnp .array (
6555 params ["base" ]["decoder" ]["decoder_norm" ]["scale" ]
6656 )
67-
6857 self .vllm_state ["vllm_model.language_model.lm_head.weight" ] = jnp .transpose (
6958 params ["base" ]["decoder" ]["logits_dense" ]["kernel" ], (1 , 0 )
7059 )
7160
72- # ------------------------------------------------------------------ #
73- # 2. Hybrid Attention (Scanned 3:1 Blocks)
74- # ------------------------------------------------------------------ #
7561 def _convert_attn (self , params ):
76- logging .info ("_convert_attn: Unstacking layer norms and routing hybrid attention..." )
7762 decoder = params ["base" ]["decoder" ]
78-
79- if "scanned_blocks" in decoder :
80- blocks = decoder ["scanned_blocks" ]
81- slot_prefix = "layers"
82- else :
83- blocks = decoder ["layers" ]
84- slot_prefix = "layer"
63+ blocks = decoder .get ("scanned_blocks" , decoder .get ("layers" ))
64+ slot_prefix = "layers" if "scanned_blocks" in decoder else "layer"
8565
8666 @jax .jit
8767 def _unstack_rep (x ):
@@ -114,20 +94,18 @@ def _unstack_rep(x):
11494
11595 q , k , v = q_layers [rep ], k_layers [rep ], v_layers [rep ]
11696
117- # Transpose to standard (num_heads, head_dim, emb_dim)
11897 q_T = jnp .transpose (q , (1 , 2 , 0 ))
11998 k_T = jnp .transpose (k , (1 , 2 , 0 ))
12099 v_T = jnp .transpose (v , (1 , 2 , 0 ))
121100
122- # Flatten head dimensions and slice for TP interleaving
123101 tp_size = self .vllm_tp
124102 q_tp_shards = jnp .split (q_T .reshape (- 1 , q .shape [0 ]), tp_size , axis = 0 )
125103 k_tp_shards = jnp .split (k_T .reshape (- 1 , k .shape [0 ]), tp_size , axis = 0 )
126104 v_tp_shards = jnp .split (v_T .reshape (- 1 , v .shape [0 ]), tp_size , axis = 0 )
127105
128- tp_interleaved = []
129- for t in range (tp_size ):
130- tp_interleaved . append ( jnp . concatenate ([ q_tp_shards [ t ], k_tp_shards [ t ], v_tp_shards [ t ]], axis = 0 ))
106+ tp_interleaved = [
107+ jnp . concatenate ([ q_tp_shards [ t ], k_tp_shards [ t ], v_tp_shards [ t ]], axis = 0 ) for t in range (tp_size )
108+ ]
131109
132110 self .vllm_state [f"{ prefix } .self_attn.qkv_proj.weight" ] = jnp .concatenate (tp_interleaved , axis = 0 )
133111 self .vllm_state [f"{ prefix } .self_attn.o_proj.weight" ] = jnp .transpose (o_layers [rep ], (1 , 0 ))
@@ -136,11 +114,9 @@ def _unstack_rep(x):
136114
137115 else :
138116 gdn = slot_data ["attention" ]
139-
140117 qkvz_layers = jnp .unstack (gdn ["in_proj_qkvz" ]["kernel" ], axis = 1 )
141118 ba_layers = jnp .unstack (gdn ["in_proj_ba" ]["kernel" ], axis = 1 )
142119 out_layers = jnp .unstack (gdn ["out_proj" ]["kernel" ], axis = 1 )
143-
144120 conv_layers = jnp .unstack (gdn ["conv1d" ]["kernel" ], axis = 1 )
145121
146122 A_log_layers = jnp .unstack (gdn ["A_log" ], axis = 1 )
@@ -154,84 +130,55 @@ def _unstack_rep(x):
154130 self .vllm_state [f"{ prefix } .input_layernorm.weight" ] = pre_ln [rep ]
155131 self .vllm_state [f"{ prefix } .post_attention_layernorm.weight" ] = post_ln [rep ]
156132
157- # Extract MaxText QKVZ layout
158- H_k = 16
159- H_v = 32
160- D_k = 128
161- D_v = 128
162- V_per_K = 2
133+ # Extract MaxText GDN QKVZ Layout
134+ H_k , H_v , D_k , D_v , V_per_K = 16 , 32 , 128 , 128 , 2
163135
164136 t_m = jnp .transpose (qkvz_layers [rep ], (1 , 0 ))
165137 block_size = D_k + D_k + V_per_K * D_v + V_per_K * D_v
166138 t_r = t_m .reshape (H_k , block_size , - 1 )
167139
168- q_r = t_r [:, :D_k , :]
169- k_r = t_r [:, D_k : 2 * D_k , :]
170- v_r = t_r [:, 2 * D_k : 2 * D_k + V_per_K * D_v , :]
171- z_r = t_r [:, 2 * D_k + V_per_K * D_v :, :]
172-
173- q = q_r .reshape (H_k * D_k , - 1 )
174- k = k_r .reshape (H_k * D_k , - 1 )
175- v = v_r .reshape (H_v * D_v , - 1 )
176- z = z_r .reshape (H_v * D_v , - 1 )
140+ q = t_r [:, :D_k , :].reshape (H_k * D_k , - 1 )
141+ k = t_r [:, D_k : 2 * D_k , :].reshape (H_k * D_k , - 1 )
142+ v = t_r [:, 2 * D_k : 2 * D_k + V_per_K * D_v , :].reshape (H_v * D_v , - 1 )
143+ z = t_r [:, 2 * D_k + V_per_K * D_v :, :].reshape (H_v * D_v , - 1 )
177144
178- # Interleave GDN QKVZ by Tensor Parallel shard
179145 tp_size = self .vllm_tp
180146 q_shards = jnp .split (q , tp_size , axis = 0 )
181147 k_shards = jnp .split (k , tp_size , axis = 0 )
182148 v_shards = jnp .split (v , tp_size , axis = 0 )
183149 z_shards = jnp .split (z , tp_size , axis = 0 )
184150
185- qkvz_interleaved_shards = []
186- for s in range (tp_size ):
187- qkvz_interleaved_shards .append (jnp .concatenate ([q_shards [s ], k_shards [s ], v_shards [s ], z_shards [s ]], axis = 0 ))
151+ qkvz_interleaved = [
152+ jnp .concatenate ([q_shards [s ], k_shards [s ], v_shards [s ], z_shards [s ]], axis = 0 ) for s in range (tp_size )
153+ ]
154+ self .vllm_state [f"{ prefix } .linear_attn.in_proj_qkvz.weight" ] = jnp .concatenate (qkvz_interleaved , axis = 0 )
188155
189- self .vllm_state [f"{ prefix } .linear_attn.in_proj_qkvz.weight" ] = jnp .concatenate (qkvz_interleaved_shards , axis = 0 )
190-
191- # Extract MaxText BA layout
156+ # Extract MaxText GDN BA Layout
192157 t_m_ba = jnp .transpose (ba_layers [rep ], (1 , 0 ))
193158 block_size_ba = V_per_K * 2
194159 t_r_ba = t_m_ba .reshape (H_k , block_size_ba , - 1 )
195160
196- b_r = t_r_ba [:, :V_per_K , :]
197- a_r = t_r_ba [:, V_per_K :, :]
198-
199- b = b_r .reshape (H_v , - 1 )
200- a = a_r .reshape (H_v , - 1 )
161+ b = t_r_ba [:, :V_per_K , :].reshape (H_v , - 1 )
162+ a = t_r_ba [:, V_per_K :, :].reshape (H_v , - 1 )
201163
202- # Interleave BA vectors by Tensor Parallel shard
203164 b_shards = jnp .split (b , tp_size , axis = 0 )
204165 a_shards = jnp .split (a , tp_size , axis = 0 )
205166
206- ba_interleaved_shards = []
207- for s in range (tp_size ):
208- ba_interleaved_shards .append (jnp .concatenate ([b_shards [s ], a_shards [s ]], axis = 0 ))
167+ ba_interleaved = [jnp .concatenate ([b_shards [s ], a_shards [s ]], axis = 0 ) for s in range (tp_size )]
168+ self .vllm_state [f"{ prefix } .linear_attn.in_proj_ba.weight" ] = jnp .concatenate (ba_interleaved , axis = 0 )
209169
210- self .vllm_state [f"{ prefix } .linear_attn.in_proj_ba.weight" ] = jnp .concatenate (ba_interleaved_shards , axis = 0 )
211170 self .vllm_state [f"{ prefix } .linear_attn.out_proj.weight" ] = jnp .transpose (out_layers [rep ], (1 , 0 ))
212-
213- # MT: [K, 1, C] <-> HF: [C, 1, K]
214- conv_w = conv_layers [rep ]
215- self .vllm_state [f"{ prefix } .linear_attn.conv1d.weight" ] = jnp .transpose (conv_w , (2 , 1 , 0 ))
171+ self .vllm_state [f"{ prefix } .linear_attn.conv1d.weight" ] = jnp .transpose (conv_layers [rep ], (2 , 1 , 0 ))
216172 self .vllm_state [f"{ prefix } .linear_attn.A_log" ] = A_log_layers [rep ]
217173 self .vllm_state [f"{ prefix } .linear_attn.dt_bias" ] = dt_bias_layers [rep ]
218174 self .vllm_state [f"{ prefix } .linear_attn.norm.weight" ] = gdn_norm_layers [rep ]
219175
220176 gc .collect ()
221177
222- # ------------------------------------------------------------------ #
223- # 3. Mixture of Experts (Scanned Block)
224- # ------------------------------------------------------------------ #
225178 def _convert_moe (self , params ):
226- logging .info ("_convert_moe: Packaging routed and shared experts..." )
227179 decoder = params ["base" ]["decoder" ]
228-
229- if "scanned_blocks" in decoder :
230- blocks = decoder ["scanned_blocks" ]
231- slot_prefix = "layers"
232- else :
233- blocks = decoder ["layers" ]
234- slot_prefix = "layer"
180+ blocks = decoder .get ("scanned_blocks" , decoder .get ("layers" ))
181+ slot_prefix = "layers" if "scanned_blocks" in decoder else "layer"
235182
236183 for slot in range (self .NUM_SLOTS ):
237184 slot_data = blocks [f"{ slot_prefix } _{ slot } " ]
@@ -245,21 +192,35 @@ def _convert_moe(self, params):
245192
246193 router_weights = jnp .unstack (jnp .transpose (routed ["gate" ]["kernel" ], (1 , 2 , 0 )), axis = 0 )
247194
248- # Fusing and Tensor Parallel Interleaving for MoE W1 and W3
195+ # -------------------------------------------------------------
196+ # Fusing, TP Interleaving, and TPU GMM Alignment for W1 and W3
197+ # -------------------------------------------------------------
249198 wi_0 = jnp .transpose (routed ["wi_0" ], (1 , 0 , 2 , 3 ))
250199 wi_1 = jnp .transpose (routed ["wi_1" ], (1 , 0 , 2 , 3 ))
251200
201+ num_reps , num_experts , d_model , d_inner = wi_0 .shape
252202 tp_size = self .vllm_tp
253- w1_shards = jnp .split (wi_0 , tp_size , axis = - 1 )
254- w3_shards = jnp .split (wi_1 , tp_size , axis = - 1 )
255203
256- interleaved_shards = []
257- for i in range (tp_size ):
258- interleaved_shards .append (w1_shards [i ])
259- interleaved_shards .append (w3_shards [i ])
204+ # vLLM's TPU Grouped GEMM kernel requires 128-alignment per expert chunk
205+ chunk_size = d_inner // tp_size
206+ padded_chunk_size = ((chunk_size + 127 ) // 128 ) * 128
207+ pad_amount = padded_chunk_size - chunk_size
208+
209+ w1_chunks = wi_0 .reshape (num_reps , num_experts , d_model , tp_size , chunk_size )
210+ w3_chunks = wi_1 .reshape (num_reps , num_experts , d_model , tp_size , chunk_size )
211+
212+ # Apply padding whenever the per-TP-shard chunk size is not a multiple of 128 (e.g. TP=8)
213+ if pad_amount > 0 :
214+ w1_chunks = jnp .pad (w1_chunks , ((0 , 0 ), (0 , 0 ), (0 , 0 ), (0 , 0 ), (0 , pad_amount )))
215+ w3_chunks = jnp .pad (w3_chunks , ((0 , 0 ), (0 , 0 ), (0 , 0 ), (0 , 0 ), (0 , pad_amount )))
216+
217+ # Interleave W1 and W3 shards -> Shape: (reps, exp, d_model, tp, 2, padded_chunk)
218+ combined_shards = jnp .stack ([w1_chunks , w3_chunks ], axis = - 2 )
260219
261- gate_up = jnp .concatenate (interleaved_shards , axis = - 1 )
220+ # Flatten the TP, 2, and chunk dimensions back into the final inner dimension
221+ gate_up = combined_shards .reshape (num_reps , num_experts , d_model , - 1 )
262222 w13_layers = jnp .unstack (gate_up , axis = 0 )
223+ # -------------------------------------------------------------
263224
264225 wo_transposed = jnp .transpose (routed ["wo" ], (1 , 0 , 2 , 3 ))
265226 down_layers = jnp .unstack (wo_transposed , axis = 0 )
@@ -282,7 +243,6 @@ def _convert_moe(self, params):
282243 self .vllm_state [f"{ p } .mlp.experts.w13_weight" ] = w13_layers [rep ]
283244 self .vllm_state [f"{ p } .mlp.experts.w2_weight" ] = down_layers [rep ]
284245
285- # Build Shared Expert structure
286246 if has_shared :
287247 sh_g , sh_u = sh_gate_layers [rep ], sh_up_layers [rep ]
288248 sh_per_tp = sh_g .shape [0 ] // self .vllm_tp
0 commit comments