Skip to content

Commit b05b211

Browse files
Ran linter
1 parent ce3adf3 commit b05b211

1 file changed

Lines changed: 46 additions & 86 deletions

File tree

src/maxtext/integration/vllm/torchax_converter/qwen35_moe.py

Lines changed: 46 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@ class Qwen35MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter):
2828
NUM_SLOTS = 4 # 3 GDN layers + 1 Full Attention layer per cycle
2929

3030
def convert(self, model_state: dict):
31-
"""Main entry point for the Tunix weight synchronization."""
3231
logging.info("\n%sStarting Qwen 3.5 Conversion (Hybrid 3:1 MoE)...%s", GREEN, RESET)
3332
self.vllm_state = {}
34-
3533
self.num_reps = self.num_layers // self.NUM_SLOTS
3634

3735
with timer("Convert Global Weights"):
@@ -43,45 +41,27 @@ def convert(self, model_state: dict):
4341
with timer("Convert MoE Weights"):
4442
self._convert_moe(model_state)
4543

46-
# ------------------------------------------------------------------ #
47-
# Protect JAX compilation
48-
# ------------------------------------------------------------------ #
44+
# Protect JAX compilation by enforcing bfloat16
4945
for key in self.vllm_state:
5046
self.vllm_state[key] = self.vllm_state[key].astype(jnp.bfloat16)
5147

5248
return self.vllm_state
5349

54-
# ------------------------------------------------------------------ #
55-
# 1. Global Weights
56-
# ------------------------------------------------------------------ #
5750
def _convert_global(self, params):
58-
logging.info("_convert_global: Processing embeddings and LM head...")
59-
6051
self.vllm_state["vllm_model.language_model.model.embed_tokens.weight"] = jnp.array(
6152
params["base"]["token_embedder"]["embedding"]
6253
)
63-
6454
self.vllm_state["vllm_model.language_model.model.norm.weight"] = jnp.array(
6555
params["base"]["decoder"]["decoder_norm"]["scale"]
6656
)
67-
6857
self.vllm_state["vllm_model.language_model.lm_head.weight"] = jnp.transpose(
6958
params["base"]["decoder"]["logits_dense"]["kernel"], (1, 0)
7059
)
7160

72-
# ------------------------------------------------------------------ #
73-
# 2. Hybrid Attention (Scanned 3:1 Blocks)
74-
# ------------------------------------------------------------------ #
7561
def _convert_attn(self, params):
76-
logging.info("_convert_attn: Unstacking layer norms and routing hybrid attention...")
7762
decoder = params["base"]["decoder"]
78-
79-
if "scanned_blocks" in decoder:
80-
blocks = decoder["scanned_blocks"]
81-
slot_prefix = "layers"
82-
else:
83-
blocks = decoder["layers"]
84-
slot_prefix = "layer"
63+
blocks = decoder.get("scanned_blocks", decoder.get("layers"))
64+
slot_prefix = "layers" if "scanned_blocks" in decoder else "layer"
8565

8666
@jax.jit
8767
def _unstack_rep(x):
@@ -114,20 +94,18 @@ def _unstack_rep(x):
11494

11595
q, k, v = q_layers[rep], k_layers[rep], v_layers[rep]
11696

117-
# Transpose to standard (num_heads, head_dim, emb_dim)
11897
q_T = jnp.transpose(q, (1, 2, 0))
11998
k_T = jnp.transpose(k, (1, 2, 0))
12099
v_T = jnp.transpose(v, (1, 2, 0))
121100

122-
# Flatten head dimensions and slice for TP interleaving
123101
tp_size = self.vllm_tp
124102
q_tp_shards = jnp.split(q_T.reshape(-1, q.shape[0]), tp_size, axis=0)
125103
k_tp_shards = jnp.split(k_T.reshape(-1, k.shape[0]), tp_size, axis=0)
126104
v_tp_shards = jnp.split(v_T.reshape(-1, v.shape[0]), tp_size, axis=0)
127105

128-
tp_interleaved = []
129-
for t in range(tp_size):
130-
tp_interleaved.append(jnp.concatenate([q_tp_shards[t], k_tp_shards[t], v_tp_shards[t]], axis=0))
106+
tp_interleaved = [
107+
jnp.concatenate([q_tp_shards[t], k_tp_shards[t], v_tp_shards[t]], axis=0) for t in range(tp_size)
108+
]
131109

132110
self.vllm_state[f"{prefix}.self_attn.qkv_proj.weight"] = jnp.concatenate(tp_interleaved, axis=0)
133111
self.vllm_state[f"{prefix}.self_attn.o_proj.weight"] = jnp.transpose(o_layers[rep], (1, 0))
@@ -136,11 +114,9 @@ def _unstack_rep(x):
136114

137115
else:
138116
gdn = slot_data["attention"]
139-
140117
qkvz_layers = jnp.unstack(gdn["in_proj_qkvz"]["kernel"], axis=1)
141118
ba_layers = jnp.unstack(gdn["in_proj_ba"]["kernel"], axis=1)
142119
out_layers = jnp.unstack(gdn["out_proj"]["kernel"], axis=1)
143-
144120
conv_layers = jnp.unstack(gdn["conv1d"]["kernel"], axis=1)
145121

146122
A_log_layers = jnp.unstack(gdn["A_log"], axis=1)
@@ -154,84 +130,55 @@ def _unstack_rep(x):
154130
self.vllm_state[f"{prefix}.input_layernorm.weight"] = pre_ln[rep]
155131
self.vllm_state[f"{prefix}.post_attention_layernorm.weight"] = post_ln[rep]
156132

157-
# Extract MaxText QKVZ layout
158-
H_k = 16
159-
H_v = 32
160-
D_k = 128
161-
D_v = 128
162-
V_per_K = 2
133+
# Extract MaxText GDN QKVZ Layout
134+
H_k, H_v, D_k, D_v, V_per_K = 16, 32, 128, 128, 2
163135

164136
t_m = jnp.transpose(qkvz_layers[rep], (1, 0))
165137
block_size = D_k + D_k + V_per_K * D_v + V_per_K * D_v
166138
t_r = t_m.reshape(H_k, block_size, -1)
167139

168-
q_r = t_r[:, :D_k, :]
169-
k_r = t_r[:, D_k : 2 * D_k, :]
170-
v_r = t_r[:, 2 * D_k : 2 * D_k + V_per_K * D_v, :]
171-
z_r = t_r[:, 2 * D_k + V_per_K * D_v :, :]
172-
173-
q = q_r.reshape(H_k * D_k, -1)
174-
k = k_r.reshape(H_k * D_k, -1)
175-
v = v_r.reshape(H_v * D_v, -1)
176-
z = z_r.reshape(H_v * D_v, -1)
140+
q = t_r[:, :D_k, :].reshape(H_k * D_k, -1)
141+
k = t_r[:, D_k : 2 * D_k, :].reshape(H_k * D_k, -1)
142+
v = t_r[:, 2 * D_k : 2 * D_k + V_per_K * D_v, :].reshape(H_v * D_v, -1)
143+
z = t_r[:, 2 * D_k + V_per_K * D_v :, :].reshape(H_v * D_v, -1)
177144

178-
# Interleave GDN QKVZ by Tensor Parallel shard
179145
tp_size = self.vllm_tp
180146
q_shards = jnp.split(q, tp_size, axis=0)
181147
k_shards = jnp.split(k, tp_size, axis=0)
182148
v_shards = jnp.split(v, tp_size, axis=0)
183149
z_shards = jnp.split(z, tp_size, axis=0)
184150

185-
qkvz_interleaved_shards = []
186-
for s in range(tp_size):
187-
qkvz_interleaved_shards.append(jnp.concatenate([q_shards[s], k_shards[s], v_shards[s], z_shards[s]], axis=0))
151+
qkvz_interleaved = [
152+
jnp.concatenate([q_shards[s], k_shards[s], v_shards[s], z_shards[s]], axis=0) for s in range(tp_size)
153+
]
154+
self.vllm_state[f"{prefix}.linear_attn.in_proj_qkvz.weight"] = jnp.concatenate(qkvz_interleaved, axis=0)
188155

189-
self.vllm_state[f"{prefix}.linear_attn.in_proj_qkvz.weight"] = jnp.concatenate(qkvz_interleaved_shards, axis=0)
190-
191-
# Extract MaxText BA layout
156+
# Extract MaxText GDN BA Layout
192157
t_m_ba = jnp.transpose(ba_layers[rep], (1, 0))
193158
block_size_ba = V_per_K * 2
194159
t_r_ba = t_m_ba.reshape(H_k, block_size_ba, -1)
195160

196-
b_r = t_r_ba[:, :V_per_K, :]
197-
a_r = t_r_ba[:, V_per_K:, :]
198-
199-
b = b_r.reshape(H_v, -1)
200-
a = a_r.reshape(H_v, -1)
161+
b = t_r_ba[:, :V_per_K, :].reshape(H_v, -1)
162+
a = t_r_ba[:, V_per_K:, :].reshape(H_v, -1)
201163

202-
# Interleave BA vectors by Tensor Parallel shard
203164
b_shards = jnp.split(b, tp_size, axis=0)
204165
a_shards = jnp.split(a, tp_size, axis=0)
205166

206-
ba_interleaved_shards = []
207-
for s in range(tp_size):
208-
ba_interleaved_shards.append(jnp.concatenate([b_shards[s], a_shards[s]], axis=0))
167+
ba_interleaved = [jnp.concatenate([b_shards[s], a_shards[s]], axis=0) for s in range(tp_size)]
168+
self.vllm_state[f"{prefix}.linear_attn.in_proj_ba.weight"] = jnp.concatenate(ba_interleaved, axis=0)
209169

210-
self.vllm_state[f"{prefix}.linear_attn.in_proj_ba.weight"] = jnp.concatenate(ba_interleaved_shards, axis=0)
211170
self.vllm_state[f"{prefix}.linear_attn.out_proj.weight"] = jnp.transpose(out_layers[rep], (1, 0))
212-
213-
# MT: [K, 1, C] <-> HF: [C, 1, K]
214-
conv_w = conv_layers[rep]
215-
self.vllm_state[f"{prefix}.linear_attn.conv1d.weight"] = jnp.transpose(conv_w, (2, 1, 0))
171+
self.vllm_state[f"{prefix}.linear_attn.conv1d.weight"] = jnp.transpose(conv_layers[rep], (2, 1, 0))
216172
self.vllm_state[f"{prefix}.linear_attn.A_log"] = A_log_layers[rep]
217173
self.vllm_state[f"{prefix}.linear_attn.dt_bias"] = dt_bias_layers[rep]
218174
self.vllm_state[f"{prefix}.linear_attn.norm.weight"] = gdn_norm_layers[rep]
219175

220176
gc.collect()
221177

222-
# ------------------------------------------------------------------ #
223-
# 3. Mixture of Experts (Scanned Block)
224-
# ------------------------------------------------------------------ #
225178
def _convert_moe(self, params):
226-
logging.info("_convert_moe: Packaging routed and shared experts...")
227179
decoder = params["base"]["decoder"]
228-
229-
if "scanned_blocks" in decoder:
230-
blocks = decoder["scanned_blocks"]
231-
slot_prefix = "layers"
232-
else:
233-
blocks = decoder["layers"]
234-
slot_prefix = "layer"
180+
blocks = decoder.get("scanned_blocks", decoder.get("layers"))
181+
slot_prefix = "layers" if "scanned_blocks" in decoder else "layer"
235182

236183
for slot in range(self.NUM_SLOTS):
237184
slot_data = blocks[f"{slot_prefix}_{slot}"]
@@ -245,21 +192,35 @@ def _convert_moe(self, params):
245192

246193
router_weights = jnp.unstack(jnp.transpose(routed["gate"]["kernel"], (1, 2, 0)), axis=0)
247194

248-
# Fusing and Tensor Parallel Interleaving for MoE W1 and W3
195+
# -------------------------------------------------------------
196+
# Fusing, TP Interleaving, and TPU GMM Alignment for W1 and W3
197+
# -------------------------------------------------------------
249198
wi_0 = jnp.transpose(routed["wi_0"], (1, 0, 2, 3))
250199
wi_1 = jnp.transpose(routed["wi_1"], (1, 0, 2, 3))
251200

201+
num_reps, num_experts, d_model, d_inner = wi_0.shape
252202
tp_size = self.vllm_tp
253-
w1_shards = jnp.split(wi_0, tp_size, axis=-1)
254-
w3_shards = jnp.split(wi_1, tp_size, axis=-1)
255203

256-
interleaved_shards = []
257-
for i in range(tp_size):
258-
interleaved_shards.append(w1_shards[i])
259-
interleaved_shards.append(w3_shards[i])
204+
# vLLM's TPU Grouped GEMM kernel requires 128-alignment per expert chunk
205+
chunk_size = d_inner // tp_size
206+
padded_chunk_size = ((chunk_size + 127) // 128) * 128
207+
pad_amount = padded_chunk_size - chunk_size
208+
209+
w1_chunks = wi_0.reshape(num_reps, num_experts, d_model, tp_size, chunk_size)
210+
w3_chunks = wi_1.reshape(num_reps, num_experts, d_model, tp_size, chunk_size)
211+
212+
# Apply padding if running on a topology that splinters chunks below 128 (e.g. TP=8)
213+
if pad_amount > 0:
214+
w1_chunks = jnp.pad(w1_chunks, ((0, 0), (0, 0), (0, 0), (0, 0), (0, pad_amount)))
215+
w3_chunks = jnp.pad(w3_chunks, ((0, 0), (0, 0), (0, 0), (0, 0), (0, pad_amount)))
216+
217+
# Interleave W1 and W3 shards -> Shape: (reps, exp, d_model, tp, 2, padded_chunk)
218+
combined_shards = jnp.stack([w1_chunks, w3_chunks], axis=-2)
260219

261-
gate_up = jnp.concatenate(interleaved_shards, axis=-1)
220+
# Flatten the TP, 2, and chunk dimensions back into the final inner dimension
221+
gate_up = combined_shards.reshape(num_reps, num_experts, d_model, -1)
262222
w13_layers = jnp.unstack(gate_up, axis=0)
223+
# -------------------------------------------------------------
263224

264225
wo_transposed = jnp.transpose(routed["wo"], (1, 0, 2, 3))
265226
down_layers = jnp.unstack(wo_transposed, axis=0)
@@ -282,7 +243,6 @@ def _convert_moe(self, params):
282243
self.vllm_state[f"{p}.mlp.experts.w13_weight"] = w13_layers[rep]
283244
self.vllm_state[f"{p}.mlp.experts.w2_weight"] = down_layers[rep]
284245

285-
# Build Shared Expert structure
286246
if has_shared:
287247
sh_g, sh_u = sh_gate_layers[rep], sh_up_layers[rep]
288248
sh_per_tp = sh_g.shape[0] // self.vllm_tp

0 commit comments

Comments
 (0)