Skip to content

Commit eed71e0

Browse files
initial commit for qwen3.5 moe weight sync script
Fix weight sync script; onboard model to validator script and rollout file; fix param mappings and converter logic to get weight conversion working; working script for weight sync; ran linter; ran linter; add max_num_batched_tokens only for qwen3.5; ran linter.
1 parent ff05b79 commit eed71e0

3 files changed

Lines changed: 283 additions & 10 deletions

File tree

src/maxtext/integration/vllm/maxtext_vllm_rollout.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,15 @@
3535
from tunix.rl.rollout import base_rollout, vllm_rollout
3636

3737
from maxtext.integration.vllm.torchax_converter.qwen3_moe import Qwen3MaxTextToVLLMConverter
38+
from maxtext.integration.vllm.torchax_converter.qwen35_moe import Qwen35MaxTextToVLLMConverter
3839

3940

4041
def _create_model_converter(model_name: str, config: Any, mesh: jax.sharding.Mesh):
4142
"""Instantiate the converter for a MaxText model name."""
4243
if model_name in {"qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-235b-a22b"}:
4344
return Qwen3MaxTextToVLLMConverter(config=config, mesh=mesh)
45+
elif model_name in {"qwen3.5-35b-a3b"}:
46+
return Qwen35MaxTextToVLLMConverter(config=config, mesh=mesh)
4447

4548
raise ValueError(f"No MaxText->vLLM converter registered for model {model_name!r}.")
4649

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Qwen 3.5 MaxText to vLLM Converter (Supports 35B MoE Hybrid Architecture)."""
16+
17+
import gc
18+
import logging
19+
import jax
20+
import jax.numpy as jnp
21+
22+
from maxtext.integration.vllm.torchax_converter.base import BaseMaxTextToVLLMConverter, timer, GREEN, RESET
23+
24+
25+
class Qwen35MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter):
  """Converts MaxText Qwen3.5 (Scanned Block) layout to vLLM execution layout.

  The decoder is read as `NUM_SLOTS` stacked "slot" subtrees (`layers_<slot>`
  under `scanned_blocks`, otherwise `layer_<slot>` under `layers`). Entry `rep`
  of slot `slot` is written out as vLLM layer `rep * NUM_SLOTS + slot`. Slot
  `FULL_ATTENTION_SLOT` is converted as full attention; every other slot is
  converted as Gated DeltaNet (GDN) linear attention.

  This class reads `self.num_layers` and `self.vllm_tp` but never sets them
  (presumably they are provided by the base class -- confirm).
  """

  NUM_SLOTS = 4  # 3 GDN layers + 1 Full Attention layer per cycle
  FULL_ATTENTION_SLOT = 3  # Index, within a cycle, of the full-attention layer.

  # GDN head geometry used to de-interleave the fused in_proj_qkvz / in_proj_ba
  # kernels. NOTE(review): these are hard-coded rather than read from the model
  # config; presumably they match qwen3.5-35b-a3b -- confirm before reusing this
  # converter for another model size.
  GDN_NUM_K_HEADS = 16
  GDN_NUM_V_HEADS = 32
  GDN_K_HEAD_DIM = 128
  GDN_V_HEAD_DIM = 128
  GDN_V_PER_K = 2

  # Each per-TP-shard chunk of the routed experts' fused gate/up weight is
  # zero-padded up to a multiple of this size.
  MOE_CHUNK_ALIGNMENT = 128

  def convert(self, model_state: dict):
    """Convert a MaxText parameter tree into a flat vLLM weight mapping.

    Args:
      model_state: MaxText parameters; everything is read from
        `model_state["base"]`.

    Returns:
      `self.vllm_state`, a dict from vLLM weight name to array, with every
      array cast to bfloat16.

    Raises:
      ValueError: If `self.num_layers` is not a multiple of `NUM_SLOTS`. Only
        `num_layers // NUM_SLOTS` full cycles are converted, so any remainder
        would otherwise be dropped without notice.
    """
    logging.info("\n%sStarting Qwen 3.5 Conversion (Hybrid 3:1 MoE)...%s", GREEN, RESET)

    if self.num_layers % self.NUM_SLOTS != 0:
      raise ValueError(
          f"num_layers={self.num_layers} is not a multiple of NUM_SLOTS={self.NUM_SLOTS}; "
          "trailing layers would be left unconverted."
      )

    self.vllm_state = {}
    self.num_reps = self.num_layers // self.NUM_SLOTS

    with timer("Convert Global Weights"):
      self._convert_global(model_state)

    with timer("Convert Hybrid Attention Weights"):
      self._convert_attn(model_state)

    with timer("Convert MoE Weights"):
      self._convert_moe(model_state)

    # Protect JAX compilation by enforcing bfloat16. Only existing keys are
    # reassigned, so updating the dict while iterating over it is safe.
    for key, weight in self.vllm_state.items():
      self.vllm_state[key] = weight.astype(jnp.bfloat16)

    return self.vllm_state

  @staticmethod
  def _slot_blocks(decoder):
    """Return `(blocks, slot_prefix)` for the decoder's per-slot subtrees.

    `blocks` is `decoder["scanned_blocks"]` when that key exists, otherwise
    `decoder.get("layers")`. `slot_prefix` is "layers" in the first case and
    "layer" in the second; slot `s` lives at `blocks[f"{slot_prefix}_{s}"]`.
    """
    blocks = decoder.get("scanned_blocks", decoder.get("layers"))
    slot_prefix = "layers" if "scanned_blocks" in decoder else "layer"
    return blocks, slot_prefix

  @staticmethod
  @jax.jit
  def _unstack_reps(x):
    """Split `x` along axis 1 into one array per repetition (jitted)."""
    return jnp.unstack(x, axis=1)

  @staticmethod
  def _interleave_tp(tensors, tp_size):
    """Regroup row-blocks so each tensor-parallel shard is contiguous.

    Every tensor is split into `tp_size` equal parts along axis 0; the result is
    the axis-0 concatenation of part 0 of every tensor (in the given order),
    then part 1 of every tensor, and so on.
    """
    parts = [jnp.split(t, tp_size, axis=0) for t in tensors]
    return jnp.concatenate([p[shard] for shard in range(tp_size) for p in parts], axis=0)

  def _vllm_layer_prefix(self, rep, slot):
    """Return the vLLM name prefix of the layer at repetition `rep`, slot `slot`."""
    return f"vllm_model.language_model.model.layers.{rep * self.NUM_SLOTS + slot}"

  def _convert_global(self, params):
    """Store the embedding, the final norm and the (transposed) LM head."""
    base = params["base"]
    decoder = base["decoder"]

    self.vllm_state["vllm_model.language_model.model.embed_tokens.weight"] = jnp.array(
        base["token_embedder"]["embedding"]
    )
    self.vllm_state["vllm_model.language_model.model.norm.weight"] = jnp.array(decoder["decoder_norm"]["scale"])
    # The logits kernel is stored with its two axes swapped.
    self.vllm_state["vllm_model.language_model.lm_head.weight"] = jnp.transpose(
        decoder["logits_dense"]["kernel"], (1, 0)
    )

  def _convert_attn(self, params):
    """Convert layer norms and attention weights (full and GDN) for every layer."""
    blocks, slot_prefix = self._slot_blocks(params["base"]["decoder"])

    for slot in range(self.NUM_SLOTS):
      slot_data = blocks[f"{slot_prefix}_{slot}"]

      pre_ln = self._unstack_reps(slot_data["input_layernorm"]["scale"])
      post_ln = self._unstack_reps(slot_data["post_attention_layernorm"]["scale"])

      if slot == self.FULL_ATTENTION_SLOT:
        # Full-attention weights sit one level deeper than the GDN weights.
        self._convert_full_attention_slot(slot, slot_data["attention"]["attention"], pre_ln, post_ln)
      else:
        self._convert_gdn_slot(slot, slot_data["attention"], pre_ln, post_ln)

      # Drop this slot's temporaries before processing the next slot.
      gc.collect()

  def _convert_full_attention_slot(self, slot, attn, pre_ln, post_ln):
    """Write norms, fused QKV, output projection and Q/K norms for one slot.

    Args:
      slot: Position of the layer within a cycle.
      attn: Subtree holding `query`, `key`, `value`, `out`, `query_norm` and
        `key_norm`.
      pre_ln: Per-repetition input layer-norm scales.
      post_ln: Per-repetition post-attention layer-norm scales.
    """
    # Move the repetition axis (axis 1) to the front, then split per repetition.
    q_layers = jnp.unstack(jnp.transpose(attn["query"]["kernel"], (1, 0, 2, 3)), axis=0)
    k_layers = jnp.unstack(jnp.transpose(attn["key"]["kernel"], (1, 0, 2, 3)), axis=0)
    v_layers = jnp.unstack(jnp.transpose(attn["value"]["kernel"], (1, 0, 2, 3)), axis=0)
    o_layers = jnp.unstack(attn["out"]["kernel"], axis=1)

    qnorm_layers = self._unstack_reps(attn["query_norm"]["scale"])
    knorm_layers = self._unstack_reps(attn["key_norm"]["scale"])

    tp_size = self.vllm_tp

    def _as_rows(kernel):
      # Move axis 0 last and flatten the other two axes into rows.
      return jnp.transpose(kernel, (1, 2, 0)).reshape(-1, kernel.shape[0])

    for rep in range(self.num_reps):
      prefix = self._vllm_layer_prefix(rep, slot)

      self.vllm_state[f"{prefix}.input_layernorm.weight"] = pre_ln[rep]
      self.vllm_state[f"{prefix}.post_attention_layernorm.weight"] = post_ln[rep]

      # Fused layout: [q_0, k_0, v_0, q_1, k_1, v_1, ...] over TP shards.
      self.vllm_state[f"{prefix}.self_attn.qkv_proj.weight"] = self._interleave_tp(
          [_as_rows(q_layers[rep]), _as_rows(k_layers[rep]), _as_rows(v_layers[rep])], tp_size
      )
      self.vllm_state[f"{prefix}.self_attn.o_proj.weight"] = jnp.transpose(o_layers[rep], (1, 0))
      self.vllm_state[f"{prefix}.self_attn.q_norm.weight"] = qnorm_layers[rep]
      self.vllm_state[f"{prefix}.self_attn.k_norm.weight"] = knorm_layers[rep]

  def _convert_gdn_slot(self, slot, gdn, pre_ln, post_ln):
    """Write norms and Gated DeltaNet (linear attention) weights for one slot.

    Args:
      slot: Position of the layer within a cycle.
      gdn: Subtree holding `in_proj_qkvz`, `in_proj_ba`, `out_proj`, `conv1d`,
        `A_log`, `dt_bias` and `norm`.
      pre_ln: Per-repetition input layer-norm scales.
      post_ln: Per-repetition post-attention layer-norm scales.
    """
    qkvz_layers = jnp.unstack(gdn["in_proj_qkvz"]["kernel"], axis=1)
    ba_layers = jnp.unstack(gdn["in_proj_ba"]["kernel"], axis=1)
    out_layers = jnp.unstack(gdn["out_proj"]["kernel"], axis=1)
    conv_layers = jnp.unstack(gdn["conv1d"]["kernel"], axis=1)

    a_log_layers = jnp.unstack(gdn["A_log"], axis=1)
    dt_bias_layers = jnp.unstack(gdn["dt_bias"], axis=1)
    gdn_norm_layers = self._unstack_reps(gdn["norm"]["rms_norm"]["scale"])

    # Loop-invariant geometry, hoisted out of the per-repetition loop.
    h_k, h_v = self.GDN_NUM_K_HEADS, self.GDN_NUM_V_HEADS
    d_k, d_v = self.GDN_K_HEAD_DIM, self.GDN_V_HEAD_DIM
    v_per_k = self.GDN_V_PER_K
    v_width = v_per_k * d_v
    qkvz_block = d_k + d_k + v_width + v_width  # rows per key head: q, k, v, z
    ba_block = v_per_k * 2  # rows per key head: b, a
    tp_size = self.vllm_tp

    for rep in range(self.num_reps):
      prefix = self._vllm_layer_prefix(rep, slot)

      self.vllm_state[f"{prefix}.input_layernorm.weight"] = pre_ln[rep]
      self.vllm_state[f"{prefix}.post_attention_layernorm.weight"] = post_ln[rep]

      # Extract MaxText GDN QKVZ layout: after the transpose, rows are treated
      # as h_k groups of [q | k | v | z] and each component is pulled out.
      grouped = jnp.transpose(qkvz_layers[rep], (1, 0)).reshape(h_k, qkvz_block, -1)

      q = grouped[:, :d_k, :].reshape(h_k * d_k, -1)
      k = grouped[:, d_k : 2 * d_k, :].reshape(h_k * d_k, -1)
      v = grouped[:, 2 * d_k : 2 * d_k + v_width, :].reshape(h_v * d_v, -1)
      z = grouped[:, 2 * d_k + v_width :, :].reshape(h_v * d_v, -1)

      self.vllm_state[f"{prefix}.linear_attn.in_proj_qkvz.weight"] = self._interleave_tp([q, k, v, z], tp_size)

      # Extract MaxText GDN BA layout: h_k groups of [b | a].
      ba_grouped = jnp.transpose(ba_layers[rep], (1, 0)).reshape(h_k, ba_block, -1)

      b = ba_grouped[:, :v_per_k, :].reshape(h_v, -1)
      a = ba_grouped[:, v_per_k:, :].reshape(h_v, -1)

      self.vllm_state[f"{prefix}.linear_attn.in_proj_ba.weight"] = self._interleave_tp([b, a], tp_size)

      self.vllm_state[f"{prefix}.linear_attn.out_proj.weight"] = jnp.transpose(out_layers[rep], (1, 0))
      self.vllm_state[f"{prefix}.linear_attn.conv1d.weight"] = jnp.transpose(conv_layers[rep], (2, 1, 0))
      self.vllm_state[f"{prefix}.linear_attn.A_log"] = a_log_layers[rep]
      self.vllm_state[f"{prefix}.linear_attn.dt_bias"] = dt_bias_layers[rep]
      self.vllm_state[f"{prefix}.linear_attn.norm.weight"] = gdn_norm_layers[rep]

  def _fuse_routed_gate_up(self, wi_0, wi_1, tp_size):
    """Fuse the routed experts' `wi_0`/`wi_1` into per-repetition w13 weights.

    Both inputs have axes 0 and 1 swapped so the repetition axis leads. The last
    axis is cut into `tp_size` chunks, each chunk is zero-padded up to a
    multiple of `MOE_CHUNK_ALIGNMENT`, and the chunks are laid out as
    [wi_0 chunk 0, wi_1 chunk 0, wi_0 chunk 1, wi_1 chunk 1, ...].

    Returns:
      A sequence with one fused array per repetition.
    """
    wi_0 = jnp.transpose(wi_0, (1, 0, 2, 3))
    wi_1 = jnp.transpose(wi_1, (1, 0, 2, 3))

    num_reps, num_experts, d_model, d_inner = wi_0.shape

    # vLLM's TPU Grouped GEMM kernel requires 128-alignment per expert chunk
    alignment = self.MOE_CHUNK_ALIGNMENT
    chunk_size = d_inner // tp_size
    padded_chunk_size = ((chunk_size + alignment - 1) // alignment) * alignment
    pad_amount = padded_chunk_size - chunk_size

    chunked_shape = (num_reps, num_experts, d_model, tp_size, chunk_size)
    w1_chunks = wi_0.reshape(chunked_shape)
    w3_chunks = wi_1.reshape(chunked_shape)

    # Pad whenever the per-shard chunk is not already a multiple of the
    # alignment (e.g. high TP degrees); only the last axis grows.
    if pad_amount > 0:
      pad_spec = ((0, 0), (0, 0), (0, 0), (0, 0), (0, pad_amount))
      w1_chunks = jnp.pad(w1_chunks, pad_spec)
      w3_chunks = jnp.pad(w3_chunks, pad_spec)

    # Interleave W1 and W3 shards -> Shape: (reps, exp, d_model, tp, 2, padded_chunk)
    combined_shards = jnp.stack([w1_chunks, w3_chunks], axis=-2)

    # Flatten the TP, 2, and chunk dimensions back into the final inner dimension
    gate_up = combined_shards.reshape(num_reps, num_experts, d_model, -1)
    return jnp.unstack(gate_up, axis=0)

  def _convert_moe(self, params):
    """Convert router, routed-expert and shared-expert weights for every layer.

    Slots without an `mlp` subtree, or whose `mlp` has no `routed_experts`, are
    skipped. Shared-expert weights and the shared-expert gate are each emitted
    only when the corresponding key is present in the slot's `mlp` subtree.
    """
    blocks, slot_prefix = self._slot_blocks(params["base"]["decoder"])
    tp_size = self.vllm_tp

    for slot in range(self.NUM_SLOTS):
      slot_data = blocks[f"{slot_prefix}_{slot}"]

      if "mlp" not in slot_data or "routed_experts" not in slot_data["mlp"]:
        continue

      mlp_block = slot_data["mlp"]
      routed = mlp_block["routed_experts"]
      has_shared = "shared_expert" in mlp_block
      has_shared_gate = "shared_expert_gate" in mlp_block

      router_weights = jnp.unstack(jnp.transpose(routed["gate"]["kernel"], (1, 2, 0)), axis=0)

      # Fusing, TP Interleaving, and TPU GMM Alignment for W1 and W3
      w13_layers = self._fuse_routed_gate_up(routed["wi_0"], routed["wi_1"], tp_size)

      # NOTE(review): w2 is emitted without the padding applied to w13 above;
      # confirm the consumer does not need w2 padded to match.
      down_layers = jnp.unstack(jnp.transpose(routed["wo"], (1, 0, 2, 3)), axis=0)

      # Bind the optional tensors unconditionally so the per-layer loop below
      # never depends on a name that was only assigned inside a conditional.
      sh_gate_layers = sh_up_layers = sh_down_layers = None
      sh_gate_router_layers = None

      # Extract Shared Experts
      if has_shared:
        shared = mlp_block["shared_expert"]
        sh_gate_layers = jnp.unstack(jnp.transpose(shared["wi_0"]["kernel"], (1, 2, 0)), axis=0)
        sh_up_layers = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0)
        sh_down_layers = jnp.unstack(jnp.transpose(shared["wo"]["kernel"], (1, 2, 0)), axis=0)

      if has_shared_gate:
        sh_gate_router_layers = jnp.unstack(
            jnp.transpose(mlp_block["shared_expert_gate"]["kernel"], (1, 2, 0)), axis=0
        )

      for rep in range(self.num_reps):
        p = self._vllm_layer_prefix(rep, slot)

        self.vllm_state[f"{p}.mlp.gate.weight"] = router_weights[rep]
        self.vllm_state[f"{p}.mlp.experts.w13_weight"] = w13_layers[rep]
        self.vllm_state[f"{p}.mlp.experts.w2_weight"] = down_layers[rep]

        if has_shared:
          sh_g, sh_u = sh_gate_layers[rep], sh_up_layers[rep]
          sh_per_tp = sh_g.shape[0] // tp_size

          # Per TP shard: that shard's gate rows followed by its up rows.
          shared_gate_up = jnp.concatenate(
              [
                  sh_g.reshape(tp_size, sh_per_tp, sh_g.shape[1]),
                  sh_u.reshape(tp_size, sh_per_tp, sh_u.shape[1]),
              ],
              axis=1,
          ).reshape(-1, sh_g.shape[1])

          self.vllm_state[f"{p}.mlp.shared_expert.gate_up_proj.weight"] = shared_gate_up
          self.vllm_state[f"{p}.mlp.shared_expert.down_proj.weight"] = sh_down_layers[rep]

        if has_shared_gate:
          self.vllm_state[f"{p}.mlp.shared_expert_gate.weight"] = sh_gate_router_layers[rep]

      # Drop this slot's temporaries before processing the next slot.
      gc.collect()

src/maxtext/integration/vllm/torchax_converter/validate_converter.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
3. loads the matching vLLM model, and
2121
4. assigns the converted weights before running a short generation check.
2222
23-
python -m maxtext.integration.vllm.torchax_converter.validate_converter \
24-
src/maxtext/configs/post_train/rl.yml model_name=qwen3-30b-a3b \
25-
tokenizer_type=huggingface tokenizer_path=Qwen/Qwen3-30B-A3B \
26-
load_parameters_path=<your_maxtext_checkpoint_path> run_name=qwen3_converter_validation \
27-
per_device_batch_size=1 max_prefill_predict_length=8 max_target_length=16 steps=1 \
28-
scan_layers=true skip_jax_distributed_system=true weight_dtype=bfloat16 \
29-
rollout_tensor_parallelism=4 hbm_utilization_vllm=0.6 async_scheduling=false \
30-
prompt="Paris is" hf_access_token=<token> use_chat_template=true
23+
python -m maxtext.integration.vllm.torchax_converter.validate_converter \
24+
src/maxtext/configs/post_train/rl.yml model_name=qwen3-30b-a3b \
25+
tokenizer_type=huggingface tokenizer_path=Qwen/Qwen3-30B-A3B \
26+
load_parameters_path=<your_maxtext_checkpoint_path> run_name=qwen3_converter_validation \
27+
per_device_batch_size=1 max_prefill_predict_length=8 max_target_length=16 steps=1 \
28+
scan_layers=true skip_jax_distributed_system=true weight_dtype=bfloat16 \
29+
rollout_tensor_parallelism=4 hbm_utilization_vllm=0.6 async_scheduling=false \
30+
prompt="Paris is" hf_access_token=<token> use_chat_template=true
3131
For multislice (e.g. 2x128-device slices), additionally pass:
3232
num_trainer_slices=1 num_samplers_slices=1
3333
@@ -70,6 +70,7 @@
7070
from maxtext.integration.vllm.torchax_converter.base import timer
7171
from maxtext.integration.vllm.torchax_converter.gemma4_moe import Gemma4MaxTextToVLLMConverter
7272
from maxtext.integration.vllm.torchax_converter.qwen3_moe import Qwen3MaxTextToVLLMConverter
73+
from maxtext.integration.vllm.torchax_converter.qwen35_moe import Qwen35MaxTextToVLLMConverter
7374
from maxtext.utils import model_creation_utils
7475

7576
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
@@ -81,6 +82,7 @@
8182
"qwen3-30b-a3b-base": "Qwen/Qwen3-30B-A3B",
8283
"qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B",
8384
"gemma4-26b": "google/gemma-4-26B-A4B",
85+
"qwen3.5-35b-a3b": "Qwen/Qwen3.5-35B-A3B",
8486
# Add more mappings as needed
8587
}
8688

@@ -150,8 +152,6 @@ def _log_weight_stats(converted_state: dict, vllm_state: dict, compare: bool) ->
150152
conv = np.array(weight_array, dtype=np.float32)
151153
# rel_frobenius = ||converted - ref||_F / ||ref||_F.
152154
# ~0 means bit-for-bit correct; ~1 or above means the content is wrong.
153-
# Unlike mean/std/min/max, this catches permutation and transposition bugs
154-
# because it is order-sensitive.
155155
rel_frob = float(np.linalg.norm(conv - ref)) / (float(np.linalg.norm(ref)) + 1e-8)
156156
logging.info(" [VLLM-REF] %s | %s", key, _weight_stats_str(vllm_state[key]))
157157
logging.info(" [DIFF] %s | rel_frobenius=%.6f", key, rel_frob)
@@ -289,6 +289,8 @@ def validate_converter(argv) -> None:
289289

290290
if trainer_config.model_name.startswith("gemma4"):
291291
converter = Gemma4MaxTextToVLLMConverter(trainer_config, mesh)
292+
elif trainer_config.model_name.startswith("qwen3.5"):
293+
converter = Qwen35MaxTextToVLLMConverter(trainer_config, mesh)
292294
else:
293295
converter = Qwen3MaxTextToVLLMConverter(trainer_config, mesh)
294296
with timer("Overall Conversion"):
@@ -317,6 +319,10 @@ def validate_converter(argv) -> None:
317319
"gpu_memory_utilization": getattr(sampler_config, "hbm_utilization_vllm", 0.5),
318320
"async_scheduling": getattr(sampler_config, "async_scheduling", False),
319321
}
322+
# Conditionally add max_num_batched_tokens only for qwen3.5
323+
if trainer_config.model_name == "qwen3.5-35b-a3b":
324+
vllm_kwargs["max_num_batched_tokens"] = 16384
325+
320326
if multislice:
321327
# Pin vLLM to its assigned sampler devices so it doesn't overlap with trainer.
322328
vllm_kwargs["additional_config"] = {

0 commit comments

Comments
 (0)