Skip to content

Commit ce3adf3

Browse files
initial commit for qwen3.5 moe weight sync script
fix weight sync script Onboard model to validator script and rollout file fix param mappings and converter logic to get weight conversion working Working script for weight sync Ran linter Ran linter Add max_num_batched_tokens only for qwen3.5
1 parent ff05b79 commit ce3adf3

3 files changed

Lines changed: 323 additions & 10 deletions

File tree

src/maxtext/integration/vllm/maxtext_vllm_rollout.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,15 @@
3535
from tunix.rl.rollout import base_rollout, vllm_rollout
3636

3737
from maxtext.integration.vllm.torchax_converter.qwen3_moe import Qwen3MaxTextToVLLMConverter
38+
from maxtext.integration.vllm.torchax_converter.qwen35_moe import Qwen35MaxTextToVLLMConverter
3839

3940

4041
def _create_model_converter(model_name: str, config: Any, mesh: jax.sharding.Mesh):
4142
"""Instantiate the converter for a MaxText model name."""
4243
if model_name in {"qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-235b-a22b"}:
4344
return Qwen3MaxTextToVLLMConverter(config=config, mesh=mesh)
45+
elif model_name in {"qwen3.5-35b-a3b"}:
46+
return Qwen35MaxTextToVLLMConverter(config=config, mesh=mesh)
4447

4548
raise ValueError(f"No MaxText->vLLM converter registered for model {model_name!r}.")
4649

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Qwen 3.5 MaxText to vLLM Converter (Supports 35B MoE Hybrid Architecture)."""
16+
17+
import gc
18+
import logging
19+
import jax
20+
import jax.numpy as jnp
21+
22+
from maxtext.integration.vllm.torchax_converter.base import BaseMaxTextToVLLMConverter, timer, GREEN, RESET
23+
24+
25+
class Qwen35MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter):
26+
"""Converts MaxText Qwen3.5 (Scanned Block) layout to vLLM execution layout."""
27+
28+
NUM_SLOTS = 4 # 3 GDN layers + 1 Full Attention layer per cycle
29+
30+
def convert(self, model_state: dict):
31+
"""Main entry point for the Tunix weight synchronization."""
32+
logging.info("\n%sStarting Qwen 3.5 Conversion (Hybrid 3:1 MoE)...%s", GREEN, RESET)
33+
self.vllm_state = {}
34+
35+
self.num_reps = self.num_layers // self.NUM_SLOTS
36+
37+
with timer("Convert Global Weights"):
38+
self._convert_global(model_state)
39+
40+
with timer("Convert Hybrid Attention Weights"):
41+
self._convert_attn(model_state)
42+
43+
with timer("Convert MoE Weights"):
44+
self._convert_moe(model_state)
45+
46+
# ------------------------------------------------------------------ #
47+
# Protect JAX compilation
48+
# ------------------------------------------------------------------ #
49+
for key in self.vllm_state:
50+
self.vllm_state[key] = self.vllm_state[key].astype(jnp.bfloat16)
51+
52+
return self.vllm_state
53+
54+
# ------------------------------------------------------------------ #
55+
# 1. Global Weights
56+
# ------------------------------------------------------------------ #
57+
def _convert_global(self, params):
58+
logging.info("_convert_global: Processing embeddings and LM head...")
59+
60+
self.vllm_state["vllm_model.language_model.model.embed_tokens.weight"] = jnp.array(
61+
params["base"]["token_embedder"]["embedding"]
62+
)
63+
64+
self.vllm_state["vllm_model.language_model.model.norm.weight"] = jnp.array(
65+
params["base"]["decoder"]["decoder_norm"]["scale"]
66+
)
67+
68+
self.vllm_state["vllm_model.language_model.lm_head.weight"] = jnp.transpose(
69+
params["base"]["decoder"]["logits_dense"]["kernel"], (1, 0)
70+
)
71+
72+
# ------------------------------------------------------------------ #
73+
# 2. Hybrid Attention (Scanned 3:1 Blocks)
74+
# ------------------------------------------------------------------ #
75+
def _convert_attn(self, params):
76+
logging.info("_convert_attn: Unstacking layer norms and routing hybrid attention...")
77+
decoder = params["base"]["decoder"]
78+
79+
if "scanned_blocks" in decoder:
80+
blocks = decoder["scanned_blocks"]
81+
slot_prefix = "layers"
82+
else:
83+
blocks = decoder["layers"]
84+
slot_prefix = "layer"
85+
86+
@jax.jit
87+
def _unstack_rep(x):
88+
return jnp.unstack(x, axis=1)
89+
90+
for slot in range(self.NUM_SLOTS):
91+
is_full_attention = slot == 3
92+
slot_data = blocks[f"{slot_prefix}_{slot}"]
93+
94+
pre_ln = _unstack_rep(slot_data["input_layernorm"]["scale"])
95+
post_ln = _unstack_rep(slot_data["post_attention_layernorm"]["scale"])
96+
97+
if is_full_attention:
98+
attn = slot_data["attention"]["attention"]
99+
100+
q_layers = jnp.unstack(jnp.transpose(attn["query"]["kernel"], (1, 0, 2, 3)), axis=0)
101+
k_layers = jnp.unstack(jnp.transpose(attn["key"]["kernel"], (1, 0, 2, 3)), axis=0)
102+
v_layers = jnp.unstack(jnp.transpose(attn["value"]["kernel"], (1, 0, 2, 3)), axis=0)
103+
o_layers = jnp.unstack(attn["out"]["kernel"], axis=1)
104+
105+
qnorm_layers = _unstack_rep(attn["query_norm"]["scale"])
106+
knorm_layers = _unstack_rep(attn["key_norm"]["scale"])
107+
108+
for rep in range(self.num_reps):
109+
i = rep * self.NUM_SLOTS + slot
110+
prefix = f"vllm_model.language_model.model.layers.{i}"
111+
112+
self.vllm_state[f"{prefix}.input_layernorm.weight"] = pre_ln[rep]
113+
self.vllm_state[f"{prefix}.post_attention_layernorm.weight"] = post_ln[rep]
114+
115+
q, k, v = q_layers[rep], k_layers[rep], v_layers[rep]
116+
117+
# Transpose to standard (num_heads, head_dim, emb_dim)
118+
q_T = jnp.transpose(q, (1, 2, 0))
119+
k_T = jnp.transpose(k, (1, 2, 0))
120+
v_T = jnp.transpose(v, (1, 2, 0))
121+
122+
# Flatten head dimensions and slice for TP interleaving
123+
tp_size = self.vllm_tp
124+
q_tp_shards = jnp.split(q_T.reshape(-1, q.shape[0]), tp_size, axis=0)
125+
k_tp_shards = jnp.split(k_T.reshape(-1, k.shape[0]), tp_size, axis=0)
126+
v_tp_shards = jnp.split(v_T.reshape(-1, v.shape[0]), tp_size, axis=0)
127+
128+
tp_interleaved = []
129+
for t in range(tp_size):
130+
tp_interleaved.append(jnp.concatenate([q_tp_shards[t], k_tp_shards[t], v_tp_shards[t]], axis=0))
131+
132+
self.vllm_state[f"{prefix}.self_attn.qkv_proj.weight"] = jnp.concatenate(tp_interleaved, axis=0)
133+
self.vllm_state[f"{prefix}.self_attn.o_proj.weight"] = jnp.transpose(o_layers[rep], (1, 0))
134+
self.vllm_state[f"{prefix}.self_attn.q_norm.weight"] = qnorm_layers[rep]
135+
self.vllm_state[f"{prefix}.self_attn.k_norm.weight"] = knorm_layers[rep]
136+
137+
else:
138+
gdn = slot_data["attention"]
139+
140+
qkvz_layers = jnp.unstack(gdn["in_proj_qkvz"]["kernel"], axis=1)
141+
ba_layers = jnp.unstack(gdn["in_proj_ba"]["kernel"], axis=1)
142+
out_layers = jnp.unstack(gdn["out_proj"]["kernel"], axis=1)
143+
144+
conv_layers = jnp.unstack(gdn["conv1d"]["kernel"], axis=1)
145+
146+
A_log_layers = jnp.unstack(gdn["A_log"], axis=1)
147+
dt_bias_layers = jnp.unstack(gdn["dt_bias"], axis=1)
148+
gdn_norm_layers = _unstack_rep(gdn["norm"]["rms_norm"]["scale"])
149+
150+
for rep in range(self.num_reps):
151+
i = rep * self.NUM_SLOTS + slot
152+
prefix = f"vllm_model.language_model.model.layers.{i}"
153+
154+
self.vllm_state[f"{prefix}.input_layernorm.weight"] = pre_ln[rep]
155+
self.vllm_state[f"{prefix}.post_attention_layernorm.weight"] = post_ln[rep]
156+
157+
# Extract MaxText QKVZ layout
158+
H_k = 16
159+
H_v = 32
160+
D_k = 128
161+
D_v = 128
162+
V_per_K = 2
163+
164+
t_m = jnp.transpose(qkvz_layers[rep], (1, 0))
165+
block_size = D_k + D_k + V_per_K * D_v + V_per_K * D_v
166+
t_r = t_m.reshape(H_k, block_size, -1)
167+
168+
q_r = t_r[:, :D_k, :]
169+
k_r = t_r[:, D_k : 2 * D_k, :]
170+
v_r = t_r[:, 2 * D_k : 2 * D_k + V_per_K * D_v, :]
171+
z_r = t_r[:, 2 * D_k + V_per_K * D_v :, :]
172+
173+
q = q_r.reshape(H_k * D_k, -1)
174+
k = k_r.reshape(H_k * D_k, -1)
175+
v = v_r.reshape(H_v * D_v, -1)
176+
z = z_r.reshape(H_v * D_v, -1)
177+
178+
# Interleave GDN QKVZ by Tensor Parallel shard
179+
tp_size = self.vllm_tp
180+
q_shards = jnp.split(q, tp_size, axis=0)
181+
k_shards = jnp.split(k, tp_size, axis=0)
182+
v_shards = jnp.split(v, tp_size, axis=0)
183+
z_shards = jnp.split(z, tp_size, axis=0)
184+
185+
qkvz_interleaved_shards = []
186+
for s in range(tp_size):
187+
qkvz_interleaved_shards.append(jnp.concatenate([q_shards[s], k_shards[s], v_shards[s], z_shards[s]], axis=0))
188+
189+
self.vllm_state[f"{prefix}.linear_attn.in_proj_qkvz.weight"] = jnp.concatenate(qkvz_interleaved_shards, axis=0)
190+
191+
# Extract MaxText BA layout
192+
t_m_ba = jnp.transpose(ba_layers[rep], (1, 0))
193+
block_size_ba = V_per_K * 2
194+
t_r_ba = t_m_ba.reshape(H_k, block_size_ba, -1)
195+
196+
b_r = t_r_ba[:, :V_per_K, :]
197+
a_r = t_r_ba[:, V_per_K:, :]
198+
199+
b = b_r.reshape(H_v, -1)
200+
a = a_r.reshape(H_v, -1)
201+
202+
# Interleave BA vectors by Tensor Parallel shard
203+
b_shards = jnp.split(b, tp_size, axis=0)
204+
a_shards = jnp.split(a, tp_size, axis=0)
205+
206+
ba_interleaved_shards = []
207+
for s in range(tp_size):
208+
ba_interleaved_shards.append(jnp.concatenate([b_shards[s], a_shards[s]], axis=0))
209+
210+
self.vllm_state[f"{prefix}.linear_attn.in_proj_ba.weight"] = jnp.concatenate(ba_interleaved_shards, axis=0)
211+
self.vllm_state[f"{prefix}.linear_attn.out_proj.weight"] = jnp.transpose(out_layers[rep], (1, 0))
212+
213+
# MT: [K, 1, C] <-> HF: [C, 1, K]
214+
conv_w = conv_layers[rep]
215+
self.vllm_state[f"{prefix}.linear_attn.conv1d.weight"] = jnp.transpose(conv_w, (2, 1, 0))
216+
self.vllm_state[f"{prefix}.linear_attn.A_log"] = A_log_layers[rep]
217+
self.vllm_state[f"{prefix}.linear_attn.dt_bias"] = dt_bias_layers[rep]
218+
self.vllm_state[f"{prefix}.linear_attn.norm.weight"] = gdn_norm_layers[rep]
219+
220+
gc.collect()
221+
222+
# ------------------------------------------------------------------ #
223+
# 3. Mixture of Experts (Scanned Block)
224+
# ------------------------------------------------------------------ #
225+
def _convert_moe(self, params):
226+
logging.info("_convert_moe: Packaging routed and shared experts...")
227+
decoder = params["base"]["decoder"]
228+
229+
if "scanned_blocks" in decoder:
230+
blocks = decoder["scanned_blocks"]
231+
slot_prefix = "layers"
232+
else:
233+
blocks = decoder["layers"]
234+
slot_prefix = "layer"
235+
236+
for slot in range(self.NUM_SLOTS):
237+
slot_data = blocks[f"{slot_prefix}_{slot}"]
238+
239+
if "mlp" not in slot_data or "routed_experts" not in slot_data["mlp"]:
240+
continue
241+
242+
mlp_block = slot_data["mlp"]
243+
routed = mlp_block["routed_experts"]
244+
has_shared = "shared_expert" in mlp_block
245+
246+
router_weights = jnp.unstack(jnp.transpose(routed["gate"]["kernel"], (1, 2, 0)), axis=0)
247+
248+
# Fusing and Tensor Parallel Interleaving for MoE W1 and W3
249+
wi_0 = jnp.transpose(routed["wi_0"], (1, 0, 2, 3))
250+
wi_1 = jnp.transpose(routed["wi_1"], (1, 0, 2, 3))
251+
252+
tp_size = self.vllm_tp
253+
w1_shards = jnp.split(wi_0, tp_size, axis=-1)
254+
w3_shards = jnp.split(wi_1, tp_size, axis=-1)
255+
256+
interleaved_shards = []
257+
for i in range(tp_size):
258+
interleaved_shards.append(w1_shards[i])
259+
interleaved_shards.append(w3_shards[i])
260+
261+
gate_up = jnp.concatenate(interleaved_shards, axis=-1)
262+
w13_layers = jnp.unstack(gate_up, axis=0)
263+
264+
wo_transposed = jnp.transpose(routed["wo"], (1, 0, 2, 3))
265+
down_layers = jnp.unstack(wo_transposed, axis=0)
266+
267+
# Extract Shared Experts
268+
if has_shared:
269+
shared = mlp_block["shared_expert"]
270+
sh_gate_layers = jnp.unstack(jnp.transpose(shared["wi_0"]["kernel"], (1, 2, 0)), axis=0)
271+
sh_up_layers = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0)
272+
sh_down_layers = jnp.unstack(jnp.transpose(shared["wo"]["kernel"], (1, 2, 0)), axis=0)
273+
274+
if "shared_expert_gate" in mlp_block:
275+
sh_gate_router_layers = jnp.unstack(jnp.transpose(mlp_block["shared_expert_gate"]["kernel"], (1, 2, 0)), axis=0)
276+
277+
for rep in range(self.num_reps):
278+
i = rep * self.NUM_SLOTS + slot
279+
p = f"vllm_model.language_model.model.layers.{i}"
280+
281+
self.vllm_state[f"{p}.mlp.gate.weight"] = router_weights[rep]
282+
self.vllm_state[f"{p}.mlp.experts.w13_weight"] = w13_layers[rep]
283+
self.vllm_state[f"{p}.mlp.experts.w2_weight"] = down_layers[rep]
284+
285+
# Build Shared Expert structure
286+
if has_shared:
287+
sh_g, sh_u = sh_gate_layers[rep], sh_up_layers[rep]
288+
sh_per_tp = sh_g.shape[0] // self.vllm_tp
289+
290+
shared_gate_up = jnp.concatenate(
291+
[
292+
sh_g.reshape(self.vllm_tp, sh_per_tp, sh_g.shape[1]),
293+
sh_u.reshape(self.vllm_tp, sh_per_tp, sh_u.shape[1]),
294+
],
295+
axis=1,
296+
).reshape(-1, sh_g.shape[1])
297+
298+
self.vllm_state[f"{p}.mlp.shared_expert.gate_up_proj.weight"] = shared_gate_up
299+
self.vllm_state[f"{p}.mlp.shared_expert.down_proj.weight"] = sh_down_layers[rep]
300+
301+
if "shared_expert_gate" in mlp_block:
302+
self.vllm_state[f"{p}.mlp.shared_expert_gate.weight"] = sh_gate_router_layers[rep]
303+
304+
gc.collect()

src/maxtext/integration/vllm/torchax_converter/validate_converter.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
3. loads the matching vLLM model, and
2121
4. assigns the converted weights before running a short generation check.
2222
23-
python -m maxtext.integration.vllm.torchax_converter.validate_converter \
24-
src/maxtext/configs/post_train/rl.yml model_name=qwen3-30b-a3b \
25-
tokenizer_type=huggingface tokenizer_path=Qwen/Qwen3-30B-A3B \
26-
load_parameters_path=<your_maxtext_checkpoint_path> run_name=qwen3_converter_validation \
27-
per_device_batch_size=1 max_prefill_predict_length=8 max_target_length=16 steps=1 \
28-
scan_layers=true skip_jax_distributed_system=true weight_dtype=bfloat16 \
29-
rollout_tensor_parallelism=4 hbm_utilization_vllm=0.6 async_scheduling=false \
30-
prompt="Paris is" hf_access_token=<token> use_chat_template=true
23+
python -m maxtext.integration.vllm.torchax_converter.validate_converter \
24+
src/maxtext/configs/post_train/rl.yml model_name=qwen3-30b-a3b \
25+
tokenizer_type=huggingface tokenizer_path=Qwen/Qwen3-30B-A3B \
26+
load_parameters_path=<your_maxtext_checkpoint_path> run_name=qwen3_converter_validation \
27+
per_device_batch_size=1 max_prefill_predict_length=8 max_target_length=16 steps=1 \
28+
scan_layers=true skip_jax_distributed_system=true weight_dtype=bfloat16 \
29+
rollout_tensor_parallelism=4 hbm_utilization_vllm=0.6 async_scheduling=false \
30+
prompt="Paris is" hf_access_token=<token> use_chat_template=true
3131
For multislice (e.g. 2x128-device slices), additionally pass:
3232
num_trainer_slices=1 num_samplers_slices=1
3333
@@ -70,6 +70,7 @@
7070
from maxtext.integration.vllm.torchax_converter.base import timer
7171
from maxtext.integration.vllm.torchax_converter.gemma4_moe import Gemma4MaxTextToVLLMConverter
7272
from maxtext.integration.vllm.torchax_converter.qwen3_moe import Qwen3MaxTextToVLLMConverter
73+
from maxtext.integration.vllm.torchax_converter.qwen35_moe import Qwen35MaxTextToVLLMConverter
7374
from maxtext.utils import model_creation_utils
7475

7576
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
@@ -81,6 +82,7 @@
8182
"qwen3-30b-a3b-base": "Qwen/Qwen3-30B-A3B",
8283
"qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B",
8384
"gemma4-26b": "google/gemma-4-26B-A4B",
85+
"qwen3.5-35b-a3b": "Qwen/Qwen3.5-35B-A3B",
8486
# Add more mappings as needed
8587
}
8688

@@ -150,8 +152,6 @@ def _log_weight_stats(converted_state: dict, vllm_state: dict, compare: bool) ->
150152
conv = np.array(weight_array, dtype=np.float32)
151153
# rel_frobenius = ||converted - ref||_F / ||ref||_F.
152154
# ~0 means bit-for-bit correct; ~1 or above means the content is wrong.
153-
# Unlike mean/std/min/max, this catches permutation and transposition bugs
154-
# because it is order-sensitive.
155155
rel_frob = float(np.linalg.norm(conv - ref)) / (float(np.linalg.norm(ref)) + 1e-8)
156156
logging.info(" [VLLM-REF] %s | %s", key, _weight_stats_str(vllm_state[key]))
157157
logging.info(" [DIFF] %s | rel_frobenius=%.6f", key, rel_frob)
@@ -289,6 +289,8 @@ def validate_converter(argv) -> None:
289289

290290
if trainer_config.model_name.startswith("gemma4"):
291291
converter = Gemma4MaxTextToVLLMConverter(trainer_config, mesh)
292+
elif trainer_config.model_name.startswith("qwen3.5"):
293+
converter = Qwen35MaxTextToVLLMConverter(trainer_config, mesh)
292294
else:
293295
converter = Qwen3MaxTextToVLLMConverter(trainer_config, mesh)
294296
with timer("Overall Conversion"):
@@ -317,6 +319,10 @@ def validate_converter(argv) -> None:
317319
"gpu_memory_utilization": getattr(sampler_config, "hbm_utilization_vllm", 0.5),
318320
"async_scheduling": getattr(sampler_config, "async_scheduling", False),
319321
}
322+
# Conditionally add max_num_batched_tokens only for qwen3.5
323+
if trainer_config.model_name == "qwen3.5-35b-a3b":
324+
vllm_kwargs["max_num_batched_tokens"] = 16384
325+
320326
if multislice:
321327
# Pin vLLM to its assigned sampler devices so it doesn't overlap with trainer.
322328
vllm_kwargs["additional_config"] = {

0 commit comments

Comments
 (0)