|
| 1 | +# Copyright 2023–2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Qwen 3.5 MaxText to vLLM Converter (Supports 35B MoE Hybrid Architecture).""" |
| 16 | + |
| 17 | +import gc |
| 18 | +import logging |
| 19 | +import jax |
| 20 | +import jax.numpy as jnp |
| 21 | + |
| 22 | +from maxtext.integration.vllm.torchax_converter.base import BaseMaxTextToVLLMConverter, timer, GREEN, RESET |
| 23 | + |
| 24 | + |
| 25 | +class Qwen35MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter): |
| 26 | + """Converts MaxText Qwen3.5 (Scanned Block) layout to vLLM execution layout.""" |
| 27 | + |
| 28 | + NUM_SLOTS = 4 # 3 GDN layers + 1 Full Attention layer per cycle |
| 29 | + |
| 30 | + def convert(self, model_state: dict): |
| 31 | + logging.info("\n%sStarting Qwen 3.5 Conversion (Hybrid 3:1 MoE)...%s", GREEN, RESET) |
| 32 | + self.vllm_state = {} |
| 33 | + self.num_reps = self.num_layers // self.NUM_SLOTS |
| 34 | + |
| 35 | + with timer("Convert Global Weights"): |
| 36 | + self._convert_global(model_state) |
| 37 | + |
| 38 | + with timer("Convert Hybrid Attention Weights"): |
| 39 | + self._convert_attn(model_state) |
| 40 | + |
| 41 | + with timer("Convert MoE Weights"): |
| 42 | + self._convert_moe(model_state) |
| 43 | + |
| 44 | + # Protect JAX compilation by enforcing bfloat16 |
| 45 | + for key in self.vllm_state: |
| 46 | + self.vllm_state[key] = self.vllm_state[key].astype(jnp.bfloat16) |
| 47 | + |
| 48 | + return self.vllm_state |
| 49 | + |
| 50 | + def _convert_global(self, params): |
| 51 | + self.vllm_state["vllm_model.language_model.model.embed_tokens.weight"] = jnp.array( |
| 52 | + params["base"]["token_embedder"]["embedding"] |
| 53 | + ) |
| 54 | + self.vllm_state["vllm_model.language_model.model.norm.weight"] = jnp.array( |
| 55 | + params["base"]["decoder"]["decoder_norm"]["scale"] |
| 56 | + ) |
| 57 | + self.vllm_state["vllm_model.language_model.lm_head.weight"] = jnp.transpose( |
| 58 | + params["base"]["decoder"]["logits_dense"]["kernel"], (1, 0) |
| 59 | + ) |
| 60 | + |
| 61 | + def _convert_attn(self, params): |
| 62 | + decoder = params["base"]["decoder"] |
| 63 | + blocks = decoder.get("scanned_blocks", decoder.get("layers")) |
| 64 | + slot_prefix = "layers" if "scanned_blocks" in decoder else "layer" |
| 65 | + |
| 66 | + @jax.jit |
| 67 | + def _unstack_rep(x): |
| 68 | + return jnp.unstack(x, axis=1) |
| 69 | + |
| 70 | + for slot in range(self.NUM_SLOTS): |
| 71 | + is_full_attention = slot == 3 |
| 72 | + slot_data = blocks[f"{slot_prefix}_{slot}"] |
| 73 | + |
| 74 | + pre_ln = _unstack_rep(slot_data["input_layernorm"]["scale"]) |
| 75 | + post_ln = _unstack_rep(slot_data["post_attention_layernorm"]["scale"]) |
| 76 | + |
| 77 | + if is_full_attention: |
| 78 | + attn = slot_data["attention"]["attention"] |
| 79 | + |
| 80 | + q_layers = jnp.unstack(jnp.transpose(attn["query"]["kernel"], (1, 0, 2, 3)), axis=0) |
| 81 | + k_layers = jnp.unstack(jnp.transpose(attn["key"]["kernel"], (1, 0, 2, 3)), axis=0) |
| 82 | + v_layers = jnp.unstack(jnp.transpose(attn["value"]["kernel"], (1, 0, 2, 3)), axis=0) |
| 83 | + o_layers = jnp.unstack(attn["out"]["kernel"], axis=1) |
| 84 | + |
| 85 | + qnorm_layers = _unstack_rep(attn["query_norm"]["scale"]) |
| 86 | + knorm_layers = _unstack_rep(attn["key_norm"]["scale"]) |
| 87 | + |
| 88 | + for rep in range(self.num_reps): |
| 89 | + i = rep * self.NUM_SLOTS + slot |
| 90 | + prefix = f"vllm_model.language_model.model.layers.{i}" |
| 91 | + |
| 92 | + self.vllm_state[f"{prefix}.input_layernorm.weight"] = pre_ln[rep] |
| 93 | + self.vllm_state[f"{prefix}.post_attention_layernorm.weight"] = post_ln[rep] |
| 94 | + |
| 95 | + q, k, v = q_layers[rep], k_layers[rep], v_layers[rep] |
| 96 | + |
| 97 | + q_T = jnp.transpose(q, (1, 2, 0)) |
| 98 | + k_T = jnp.transpose(k, (1, 2, 0)) |
| 99 | + v_T = jnp.transpose(v, (1, 2, 0)) |
| 100 | + |
| 101 | + tp_size = self.vllm_tp |
| 102 | + q_tp_shards = jnp.split(q_T.reshape(-1, q.shape[0]), tp_size, axis=0) |
| 103 | + k_tp_shards = jnp.split(k_T.reshape(-1, k.shape[0]), tp_size, axis=0) |
| 104 | + v_tp_shards = jnp.split(v_T.reshape(-1, v.shape[0]), tp_size, axis=0) |
| 105 | + |
| 106 | + tp_interleaved = [ |
| 107 | + jnp.concatenate([q_tp_shards[t], k_tp_shards[t], v_tp_shards[t]], axis=0) for t in range(tp_size) |
| 108 | + ] |
| 109 | + |
| 110 | + self.vllm_state[f"{prefix}.self_attn.qkv_proj.weight"] = jnp.concatenate(tp_interleaved, axis=0) |
| 111 | + self.vllm_state[f"{prefix}.self_attn.o_proj.weight"] = jnp.transpose(o_layers[rep], (1, 0)) |
| 112 | + self.vllm_state[f"{prefix}.self_attn.q_norm.weight"] = qnorm_layers[rep] |
| 113 | + self.vllm_state[f"{prefix}.self_attn.k_norm.weight"] = knorm_layers[rep] |
| 114 | + |
| 115 | + else: |
| 116 | + gdn = slot_data["attention"] |
| 117 | + qkvz_layers = jnp.unstack(gdn["in_proj_qkvz"]["kernel"], axis=1) |
| 118 | + ba_layers = jnp.unstack(gdn["in_proj_ba"]["kernel"], axis=1) |
| 119 | + out_layers = jnp.unstack(gdn["out_proj"]["kernel"], axis=1) |
| 120 | + conv_layers = jnp.unstack(gdn["conv1d"]["kernel"], axis=1) |
| 121 | + |
| 122 | + A_log_layers = jnp.unstack(gdn["A_log"], axis=1) |
| 123 | + dt_bias_layers = jnp.unstack(gdn["dt_bias"], axis=1) |
| 124 | + gdn_norm_layers = _unstack_rep(gdn["norm"]["rms_norm"]["scale"]) |
| 125 | + |
| 126 | + for rep in range(self.num_reps): |
| 127 | + i = rep * self.NUM_SLOTS + slot |
| 128 | + prefix = f"vllm_model.language_model.model.layers.{i}" |
| 129 | + |
| 130 | + self.vllm_state[f"{prefix}.input_layernorm.weight"] = pre_ln[rep] |
| 131 | + self.vllm_state[f"{prefix}.post_attention_layernorm.weight"] = post_ln[rep] |
| 132 | + |
| 133 | + # Extract MaxText GDN QKVZ Layout |
| 134 | + H_k, H_v, D_k, D_v, V_per_K = 16, 32, 128, 128, 2 |
| 135 | + |
| 136 | + t_m = jnp.transpose(qkvz_layers[rep], (1, 0)) |
| 137 | + block_size = D_k + D_k + V_per_K * D_v + V_per_K * D_v |
| 138 | + t_r = t_m.reshape(H_k, block_size, -1) |
| 139 | + |
| 140 | + q = t_r[:, :D_k, :].reshape(H_k * D_k, -1) |
| 141 | + k = t_r[:, D_k : 2 * D_k, :].reshape(H_k * D_k, -1) |
| 142 | + v = t_r[:, 2 * D_k : 2 * D_k + V_per_K * D_v, :].reshape(H_v * D_v, -1) |
| 143 | + z = t_r[:, 2 * D_k + V_per_K * D_v :, :].reshape(H_v * D_v, -1) |
| 144 | + |
| 145 | + tp_size = self.vllm_tp |
| 146 | + q_shards = jnp.split(q, tp_size, axis=0) |
| 147 | + k_shards = jnp.split(k, tp_size, axis=0) |
| 148 | + v_shards = jnp.split(v, tp_size, axis=0) |
| 149 | + z_shards = jnp.split(z, tp_size, axis=0) |
| 150 | + |
| 151 | + qkvz_interleaved = [ |
| 152 | + jnp.concatenate([q_shards[s], k_shards[s], v_shards[s], z_shards[s]], axis=0) for s in range(tp_size) |
| 153 | + ] |
| 154 | + self.vllm_state[f"{prefix}.linear_attn.in_proj_qkvz.weight"] = jnp.concatenate(qkvz_interleaved, axis=0) |
| 155 | + |
| 156 | + # Extract MaxText GDN BA Layout |
| 157 | + t_m_ba = jnp.transpose(ba_layers[rep], (1, 0)) |
| 158 | + block_size_ba = V_per_K * 2 |
| 159 | + t_r_ba = t_m_ba.reshape(H_k, block_size_ba, -1) |
| 160 | + |
| 161 | + b = t_r_ba[:, :V_per_K, :].reshape(H_v, -1) |
| 162 | + a = t_r_ba[:, V_per_K:, :].reshape(H_v, -1) |
| 163 | + |
| 164 | + b_shards = jnp.split(b, tp_size, axis=0) |
| 165 | + a_shards = jnp.split(a, tp_size, axis=0) |
| 166 | + |
| 167 | + ba_interleaved = [jnp.concatenate([b_shards[s], a_shards[s]], axis=0) for s in range(tp_size)] |
| 168 | + self.vllm_state[f"{prefix}.linear_attn.in_proj_ba.weight"] = jnp.concatenate(ba_interleaved, axis=0) |
| 169 | + |
| 170 | + self.vllm_state[f"{prefix}.linear_attn.out_proj.weight"] = jnp.transpose(out_layers[rep], (1, 0)) |
| 171 | + self.vllm_state[f"{prefix}.linear_attn.conv1d.weight"] = jnp.transpose(conv_layers[rep], (2, 1, 0)) |
| 172 | + self.vllm_state[f"{prefix}.linear_attn.A_log"] = A_log_layers[rep] |
| 173 | + self.vllm_state[f"{prefix}.linear_attn.dt_bias"] = dt_bias_layers[rep] |
| 174 | + self.vllm_state[f"{prefix}.linear_attn.norm.weight"] = gdn_norm_layers[rep] |
| 175 | + |
| 176 | + gc.collect() |
| 177 | + |
| 178 | + def _convert_moe(self, params): |
| 179 | + decoder = params["base"]["decoder"] |
| 180 | + blocks = decoder.get("scanned_blocks", decoder.get("layers")) |
| 181 | + slot_prefix = "layers" if "scanned_blocks" in decoder else "layer" |
| 182 | + |
| 183 | + for slot in range(self.NUM_SLOTS): |
| 184 | + slot_data = blocks[f"{slot_prefix}_{slot}"] |
| 185 | + |
| 186 | + if "mlp" not in slot_data or "routed_experts" not in slot_data["mlp"]: |
| 187 | + continue |
| 188 | + |
| 189 | + mlp_block = slot_data["mlp"] |
| 190 | + routed = mlp_block["routed_experts"] |
| 191 | + has_shared = "shared_expert" in mlp_block |
| 192 | + |
| 193 | + router_weights = jnp.unstack(jnp.transpose(routed["gate"]["kernel"], (1, 2, 0)), axis=0) |
| 194 | + |
| 195 | + # ------------------------------------------------------------- |
| 196 | + # Fusing, TP Interleaving, and TPU GMM Alignment for W1 and W3 |
| 197 | + # ------------------------------------------------------------- |
| 198 | + wi_0 = jnp.transpose(routed["wi_0"], (1, 0, 2, 3)) |
| 199 | + wi_1 = jnp.transpose(routed["wi_1"], (1, 0, 2, 3)) |
| 200 | + |
| 201 | + num_reps, num_experts, d_model, d_inner = wi_0.shape |
| 202 | + tp_size = self.vllm_tp |
| 203 | + |
| 204 | + # vLLM's TPU Grouped GEMM kernel requires 128-alignment per expert chunk |
| 205 | + chunk_size = d_inner // tp_size |
| 206 | + padded_chunk_size = ((chunk_size + 127) // 128) * 128 |
| 207 | + pad_amount = padded_chunk_size - chunk_size |
| 208 | + |
| 209 | + w1_chunks = wi_0.reshape(num_reps, num_experts, d_model, tp_size, chunk_size) |
| 210 | + w3_chunks = wi_1.reshape(num_reps, num_experts, d_model, tp_size, chunk_size) |
| 211 | + |
| 212 | + # Apply padding if running on a topology that splinters chunks below 128 (e.g. TP=8) |
| 213 | + if pad_amount > 0: |
| 214 | + w1_chunks = jnp.pad(w1_chunks, ((0, 0), (0, 0), (0, 0), (0, 0), (0, pad_amount))) |
| 215 | + w3_chunks = jnp.pad(w3_chunks, ((0, 0), (0, 0), (0, 0), (0, 0), (0, pad_amount))) |
| 216 | + |
| 217 | + # Interleave W1 and W3 shards -> Shape: (reps, exp, d_model, tp, 2, padded_chunk) |
| 218 | + combined_shards = jnp.stack([w1_chunks, w3_chunks], axis=-2) |
| 219 | + |
| 220 | + # Flatten the TP, 2, and chunk dimensions back into the final inner dimension |
| 221 | + gate_up = combined_shards.reshape(num_reps, num_experts, d_model, -1) |
| 222 | + w13_layers = jnp.unstack(gate_up, axis=0) |
| 223 | + # ------------------------------------------------------------- |
| 224 | + |
| 225 | + wo_transposed = jnp.transpose(routed["wo"], (1, 0, 2, 3)) |
| 226 | + down_layers = jnp.unstack(wo_transposed, axis=0) |
| 227 | + |
| 228 | + # Extract Shared Experts |
| 229 | + if has_shared: |
| 230 | + shared = mlp_block["shared_expert"] |
| 231 | + sh_gate_layers = jnp.unstack(jnp.transpose(shared["wi_0"]["kernel"], (1, 2, 0)), axis=0) |
| 232 | + sh_up_layers = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0) |
| 233 | + sh_down_layers = jnp.unstack(jnp.transpose(shared["wo"]["kernel"], (1, 2, 0)), axis=0) |
| 234 | + |
| 235 | + if "shared_expert_gate" in mlp_block: |
| 236 | + sh_gate_router_layers = jnp.unstack(jnp.transpose(mlp_block["shared_expert_gate"]["kernel"], (1, 2, 0)), axis=0) |
| 237 | + |
| 238 | + for rep in range(self.num_reps): |
| 239 | + i = rep * self.NUM_SLOTS + slot |
| 240 | + p = f"vllm_model.language_model.model.layers.{i}" |
| 241 | + |
| 242 | + self.vllm_state[f"{p}.mlp.gate.weight"] = router_weights[rep] |
| 243 | + self.vllm_state[f"{p}.mlp.experts.w13_weight"] = w13_layers[rep] |
| 244 | + self.vllm_state[f"{p}.mlp.experts.w2_weight"] = down_layers[rep] |
| 245 | + |
| 246 | + if has_shared: |
| 247 | + sh_g, sh_u = sh_gate_layers[rep], sh_up_layers[rep] |
| 248 | + sh_per_tp = sh_g.shape[0] // self.vllm_tp |
| 249 | + |
| 250 | + shared_gate_up = jnp.concatenate( |
| 251 | + [ |
| 252 | + sh_g.reshape(self.vllm_tp, sh_per_tp, sh_g.shape[1]), |
| 253 | + sh_u.reshape(self.vllm_tp, sh_per_tp, sh_u.shape[1]), |
| 254 | + ], |
| 255 | + axis=1, |
| 256 | + ).reshape(-1, sh_g.shape[1]) |
| 257 | + |
| 258 | + self.vllm_state[f"{p}.mlp.shared_expert.gate_up_proj.weight"] = shared_gate_up |
| 259 | + self.vllm_state[f"{p}.mlp.shared_expert.down_proj.weight"] = sh_down_layers[rep] |
| 260 | + |
| 261 | + if "shared_expert_gate" in mlp_block: |
| 262 | + self.vllm_state[f"{p}.mlp.shared_expert_gate.weight"] = sh_gate_router_layers[rep] |
| 263 | + |
| 264 | + gc.collect() |
0 commit comments