Skip to content

Commit 0101424

Browse files
initial commit for qwen3.5 moe weight sync script
fix weight sync script Onboard model to validator script and rollout file fix param mappings and converter logic to get weight conversion working Working script for weight sync Ran linter Ran linter Add max_num_batched_tokens only for qwen3.5 Ran linter Remove hardcoded configs and use model specific ones Ran linter
1 parent ff05b79 commit 0101424

3 files changed

Lines changed: 287 additions & 10 deletions

File tree

src/maxtext/integration/vllm/maxtext_vllm_rollout.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,15 @@
3535
from tunix.rl.rollout import base_rollout, vllm_rollout
3636

3737
from maxtext.integration.vllm.torchax_converter.qwen3_moe import Qwen3MaxTextToVLLMConverter
38+
from maxtext.integration.vllm.torchax_converter.qwen35_moe import Qwen35MaxTextToVLLMConverter
3839

3940

4041
def _create_model_converter(model_name: str, config: Any, mesh: jax.sharding.Mesh):
4142
"""Instantiate the converter for a MaxText model name."""
4243
if model_name in {"qwen3-30b-a3b", "qwen3-30b-a3b-base", "qwen3-235b-a22b"}:
4344
return Qwen3MaxTextToVLLMConverter(config=config, mesh=mesh)
45+
elif model_name in {"qwen3.5-35b-a3b"}:
46+
return Qwen35MaxTextToVLLMConverter(config=config, mesh=mesh)
4447

4548
raise ValueError(f"No MaxText->vLLM converter registered for model {model_name!r}.")
4649

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Qwen 3.5 MaxText to vLLM Converter (Supports 35B MoE Hybrid Architecture)."""
16+
17+
import gc
18+
import logging
19+
import jax
20+
import jax.numpy as jnp
21+
22+
from maxtext.integration.vllm.torchax_converter.base import BaseMaxTextToVLLMConverter, timer, GREEN, RESET
23+
24+
25+
class Qwen35MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter):
26+
"""Converts MaxText Qwen3.5 (Scanned Block) layout to vLLM execution layout."""
27+
28+
def convert(self, model_state: dict):
29+
logging.info("\n%sStarting Qwen 3.5 Conversion (Hybrid MoE)...%s", GREEN, RESET)
30+
self.vllm_state = {}
31+
32+
# Generalize architecture slots dynamically via config
33+
self.num_slots = getattr(self.config, "inhomogeneous_layer_cycle_interval", 4)
34+
self.num_reps = self.num_layers // self.num_slots
35+
36+
with timer("Convert Global Weights"):
37+
self._convert_global(model_state)
38+
39+
with timer("Convert Hybrid Attention Weights"):
40+
self._convert_attn(model_state)
41+
42+
with timer("Convert MoE Weights"):
43+
self._convert_moe(model_state)
44+
45+
# Protect JAX compilation by enforcing bfloat16 without iterating dict keys directly
46+
self.vllm_state = {key: weight.astype(jnp.bfloat16) for key, weight in self.vllm_state.items()}
47+
48+
return self.vllm_state
49+
50+
def _convert_global(self, params):
51+
self.vllm_state["vllm_model.language_model.model.embed_tokens.weight"] = jnp.array(
52+
params["base"]["token_embedder"]["embedding"]
53+
)
54+
self.vllm_state["vllm_model.language_model.model.norm.weight"] = jnp.array(
55+
params["base"]["decoder"]["decoder_norm"]["scale"]
56+
)
57+
self.vllm_state["vllm_model.language_model.lm_head.weight"] = jnp.transpose(
58+
params["base"]["decoder"]["logits_dense"]["kernel"], (1, 0)
59+
)
60+
61+
def _convert_attn(self, params):
62+
decoder = params["base"]["decoder"]
63+
blocks = decoder.get("scanned_blocks", decoder.get("layers"))
64+
slot_prefix = "layers" if "scanned_blocks" in decoder else "layer"
65+
66+
@jax.jit
67+
def _unstack_rep(x):
68+
return jnp.unstack(x, axis=1)
69+
70+
for slot in range(self.num_slots):
71+
is_full_attention = slot == self.num_slots - 1
72+
slot_data = blocks[f"{slot_prefix}_{slot}"]
73+
74+
pre_ln = _unstack_rep(slot_data["input_layernorm"]["scale"])
75+
post_ln = _unstack_rep(slot_data["post_attention_layernorm"]["scale"])
76+
77+
if is_full_attention:
78+
attn = slot_data["attention"]["attention"]
79+
80+
q_layers = jnp.unstack(jnp.transpose(attn["query"]["kernel"], (1, 0, 2, 3)), axis=0)
81+
k_layers = jnp.unstack(jnp.transpose(attn["key"]["kernel"], (1, 0, 2, 3)), axis=0)
82+
v_layers = jnp.unstack(jnp.transpose(attn["value"]["kernel"], (1, 0, 2, 3)), axis=0)
83+
o_layers = jnp.unstack(attn["out"]["kernel"], axis=1)
84+
85+
qnorm_layers = _unstack_rep(attn["query_norm"]["scale"])
86+
knorm_layers = _unstack_rep(attn["key_norm"]["scale"])
87+
88+
for rep in range(self.num_reps):
89+
i = rep * self.num_slots + slot
90+
prefix = f"vllm_model.language_model.model.layers.{i}"
91+
92+
self.vllm_state[f"{prefix}.input_layernorm.weight"] = pre_ln[rep]
93+
self.vllm_state[f"{prefix}.post_attention_layernorm.weight"] = post_ln[rep]
94+
95+
q, k, v = q_layers[rep], k_layers[rep], v_layers[rep]
96+
97+
q_T = jnp.transpose(q, (1, 2, 0))
98+
k_T = jnp.transpose(k, (1, 2, 0))
99+
v_T = jnp.transpose(v, (1, 2, 0))
100+
101+
tp_size = self.vllm_tp
102+
q_tp_shards = jnp.split(q_T.reshape(-1, q.shape[0]), tp_size, axis=0)
103+
k_tp_shards = jnp.split(k_T.reshape(-1, k.shape[0]), tp_size, axis=0)
104+
v_tp_shards = jnp.split(v_T.reshape(-1, v.shape[0]), tp_size, axis=0)
105+
106+
tp_interleaved = [
107+
jnp.concatenate([q_tp_shards[t], k_tp_shards[t], v_tp_shards[t]], axis=0) for t in range(tp_size)
108+
]
109+
110+
self.vllm_state[f"{prefix}.self_attn.qkv_proj.weight"] = jnp.concatenate(tp_interleaved, axis=0)
111+
self.vllm_state[f"{prefix}.self_attn.o_proj.weight"] = jnp.transpose(o_layers[rep], (1, 0))
112+
self.vllm_state[f"{prefix}.self_attn.q_norm.weight"] = qnorm_layers[rep]
113+
self.vllm_state[f"{prefix}.self_attn.k_norm.weight"] = knorm_layers[rep]
114+
115+
else:
116+
gdn = slot_data["attention"]
117+
qkvz_layers = jnp.unstack(gdn["in_proj_qkvz"]["kernel"], axis=1)
118+
ba_layers = jnp.unstack(gdn["in_proj_ba"]["kernel"], axis=1)
119+
out_layers = jnp.unstack(gdn["out_proj"]["kernel"], axis=1)
120+
conv_layers = jnp.unstack(gdn["conv1d"]["kernel"], axis=1)
121+
122+
A_log_layers = jnp.unstack(gdn["A_log"], axis=1)
123+
dt_bias_layers = jnp.unstack(gdn["dt_bias"], axis=1)
124+
gdn_norm_layers = _unstack_rep(gdn["norm"]["rms_norm"]["scale"])
125+
126+
for rep in range(self.num_reps):
127+
i = rep * self.num_slots + slot
128+
prefix = f"vllm_model.language_model.model.layers.{i}"
129+
130+
self.vllm_state[f"{prefix}.input_layernorm.weight"] = pre_ln[rep]
131+
self.vllm_state[f"{prefix}.post_attention_layernorm.weight"] = post_ln[rep]
132+
133+
# Extract MaxText GDN QKVZ Layout dynamically via config
134+
H_k = getattr(self.config, "gdn_num_key_heads", 16)
135+
H_v = getattr(self.config, "gdn_num_value_heads", 32)
136+
D_k = getattr(self.config, "gdn_key_head_dim", 128)
137+
D_v = getattr(self.config, "gdn_value_head_dim", 128)
138+
V_per_K = H_v // H_k
139+
140+
t_m = jnp.transpose(qkvz_layers[rep], (1, 0))
141+
block_size = D_k + D_k + V_per_K * D_v + V_per_K * D_v
142+
t_r = t_m.reshape(H_k, block_size, -1)
143+
144+
q = t_r[:, :D_k, :].reshape(H_k * D_k, -1)
145+
k = t_r[:, D_k : 2 * D_k, :].reshape(H_k * D_k, -1)
146+
v = t_r[:, 2 * D_k : 2 * D_k + V_per_K * D_v, :].reshape(H_v * D_v, -1)
147+
z = t_r[:, 2 * D_k + V_per_K * D_v :, :].reshape(H_v * D_v, -1)
148+
149+
tp_size = self.vllm_tp
150+
q_shards = jnp.split(q, tp_size, axis=0)
151+
k_shards = jnp.split(k, tp_size, axis=0)
152+
v_shards = jnp.split(v, tp_size, axis=0)
153+
z_shards = jnp.split(z, tp_size, axis=0)
154+
155+
qkvz_interleaved = [
156+
jnp.concatenate([q_shards[s], k_shards[s], v_shards[s], z_shards[s]], axis=0) for s in range(tp_size)
157+
]
158+
self.vllm_state[f"{prefix}.linear_attn.in_proj_qkvz.weight"] = jnp.concatenate(qkvz_interleaved, axis=0)
159+
160+
# Extract MaxText GDN BA Layout
161+
t_m_ba = jnp.transpose(ba_layers[rep], (1, 0))
162+
block_size_ba = V_per_K * 2
163+
t_r_ba = t_m_ba.reshape(H_k, block_size_ba, -1)
164+
165+
b = t_r_ba[:, :V_per_K, :].reshape(H_v, -1)
166+
a = t_r_ba[:, V_per_K:, :].reshape(H_v, -1)
167+
168+
b_shards = jnp.split(b, tp_size, axis=0)
169+
a_shards = jnp.split(a, tp_size, axis=0)
170+
171+
ba_interleaved = [jnp.concatenate([b_shards[s], a_shards[s]], axis=0) for s in range(tp_size)]
172+
self.vllm_state[f"{prefix}.linear_attn.in_proj_ba.weight"] = jnp.concatenate(ba_interleaved, axis=0)
173+
174+
self.vllm_state[f"{prefix}.linear_attn.out_proj.weight"] = jnp.transpose(out_layers[rep], (1, 0))
175+
self.vllm_state[f"{prefix}.linear_attn.conv1d.weight"] = jnp.transpose(conv_layers[rep], (2, 1, 0))
176+
self.vllm_state[f"{prefix}.linear_attn.A_log"] = A_log_layers[rep]
177+
self.vllm_state[f"{prefix}.linear_attn.dt_bias"] = dt_bias_layers[rep]
178+
self.vllm_state[f"{prefix}.linear_attn.norm.weight"] = gdn_norm_layers[rep]
179+
180+
gc.collect()
181+
182+
def _convert_moe(self, params):
183+
decoder = params["base"]["decoder"]
184+
blocks = decoder.get("scanned_blocks", decoder.get("layers"))
185+
slot_prefix = "layers" if "scanned_blocks" in decoder else "layer"
186+
187+
for slot in range(self.num_slots):
188+
slot_data = blocks[f"{slot_prefix}_{slot}"]
189+
190+
if "mlp" not in slot_data or "routed_experts" not in slot_data["mlp"]:
191+
continue
192+
193+
mlp_block = slot_data["mlp"]
194+
routed = mlp_block["routed_experts"]
195+
has_shared = "shared_expert" in mlp_block
196+
197+
router_weights = jnp.unstack(jnp.transpose(routed["gate"]["kernel"], (1, 2, 0)), axis=0)
198+
199+
# -------------------------------------------------------------
200+
# Fusing, TP Interleaving, and TPU GMM Alignment for W1 and W3
201+
# -------------------------------------------------------------
202+
wi_0 = jnp.transpose(routed["wi_0"], (1, 0, 2, 3))
203+
wi_1 = jnp.transpose(routed["wi_1"], (1, 0, 2, 3))
204+
205+
num_reps, num_experts, d_model, d_inner = wi_0.shape
206+
tp_size = self.vllm_tp
207+
208+
# vLLM's TPU Grouped GEMM kernel requires 128-alignment per expert chunk
209+
chunk_size = d_inner // tp_size
210+
padded_chunk_size = ((chunk_size + 127) // 128) * 128
211+
pad_amount = padded_chunk_size - chunk_size
212+
213+
w1_chunks = wi_0.reshape(num_reps, num_experts, d_model, tp_size, chunk_size)
214+
w3_chunks = wi_1.reshape(num_reps, num_experts, d_model, tp_size, chunk_size)
215+
216+
# Apply padding if running on a topology that splinters chunks below 128 (e.g. TP=8)
217+
if pad_amount > 0:
218+
w1_chunks = jnp.pad(w1_chunks, ((0, 0), (0, 0), (0, 0), (0, 0), (0, pad_amount)))
219+
w3_chunks = jnp.pad(w3_chunks, ((0, 0), (0, 0), (0, 0), (0, 0), (0, pad_amount)))
220+
221+
# Interleave W1 and W3 shards -> Shape: (reps, exp, d_model, tp, 2, padded_chunk)
222+
combined_shards = jnp.stack([w1_chunks, w3_chunks], axis=-2)
223+
224+
# Flatten the TP, 2, and chunk dimensions back into the final inner dimension
225+
gate_up = combined_shards.reshape(num_reps, num_experts, d_model, -1)
226+
w13_layers = jnp.unstack(gate_up, axis=0)
227+
# -------------------------------------------------------------
228+
229+
wo_transposed = jnp.transpose(routed["wo"], (1, 0, 2, 3))
230+
down_layers = jnp.unstack(wo_transposed, axis=0)
231+
232+
# Extract Shared Experts
233+
if has_shared:
234+
shared = mlp_block["shared_expert"]
235+
sh_gate_layers = jnp.unstack(jnp.transpose(shared["wi_0"]["kernel"], (1, 2, 0)), axis=0)
236+
sh_up_layers = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0)
237+
sh_down_layers = jnp.unstack(jnp.transpose(shared["wo"]["kernel"], (1, 2, 0)), axis=0)
238+
239+
if "shared_expert_gate" in mlp_block:
240+
sh_gate_router_layers = jnp.unstack(jnp.transpose(mlp_block["shared_expert_gate"]["kernel"], (1, 2, 0)), axis=0)
241+
242+
for rep in range(self.num_reps):
243+
i = rep * self.num_slots + slot
244+
p = f"vllm_model.language_model.model.layers.{i}"
245+
246+
self.vllm_state[f"{p}.mlp.gate.weight"] = router_weights[rep]
247+
self.vllm_state[f"{p}.mlp.experts.w13_weight"] = w13_layers[rep]
248+
self.vllm_state[f"{p}.mlp.experts.w2_weight"] = down_layers[rep]
249+
250+
if has_shared:
251+
sh_g, sh_u = sh_gate_layers[rep], sh_up_layers[rep]
252+
sh_per_tp = sh_g.shape[0] // self.vllm_tp
253+
254+
shared_gate_up = jnp.concatenate(
255+
[
256+
sh_g.reshape(self.vllm_tp, sh_per_tp, sh_g.shape[1]),
257+
sh_u.reshape(self.vllm_tp, sh_per_tp, sh_u.shape[1]),
258+
],
259+
axis=1,
260+
).reshape(-1, sh_g.shape[1])
261+
262+
self.vllm_state[f"{p}.mlp.shared_expert.gate_up_proj.weight"] = shared_gate_up
263+
self.vllm_state[f"{p}.mlp.shared_expert.down_proj.weight"] = sh_down_layers[rep]
264+
265+
if "shared_expert_gate" in mlp_block:
266+
self.vllm_state[f"{p}.mlp.shared_expert_gate.weight"] = sh_gate_router_layers[rep]
267+
268+
gc.collect()

src/maxtext/integration/vllm/torchax_converter/validate_converter.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
3. loads the matching vLLM model, and
2121
4. assigns the converted weights before running a short generation check.
2222
23-
python -m maxtext.integration.vllm.torchax_converter.validate_converter \
24-
src/maxtext/configs/post_train/rl.yml model_name=qwen3-30b-a3b \
25-
tokenizer_type=huggingface tokenizer_path=Qwen/Qwen3-30B-A3B \
26-
load_parameters_path=<your_maxtext_checkpoint_path> run_name=qwen3_converter_validation \
27-
per_device_batch_size=1 max_prefill_predict_length=8 max_target_length=16 steps=1 \
28-
scan_layers=true skip_jax_distributed_system=true weight_dtype=bfloat16 \
29-
rollout_tensor_parallelism=4 hbm_utilization_vllm=0.6 async_scheduling=false \
30-
prompt="Paris is" hf_access_token=<token> use_chat_template=true
23+
python -m maxtext.integration.vllm.torchax_converter.validate_converter \
24+
src/maxtext/configs/post_train/rl.yml model_name=qwen3-30b-a3b \
25+
tokenizer_type=huggingface tokenizer_path=Qwen/Qwen3-30B-A3B \
26+
load_parameters_path=<your_maxtext_checkpoint_path> run_name=qwen3_converter_validation \
27+
per_device_batch_size=1 max_prefill_predict_length=8 max_target_length=16 steps=1 \
28+
scan_layers=true skip_jax_distributed_system=true weight_dtype=bfloat16 \
29+
rollout_tensor_parallelism=4 hbm_utilization_vllm=0.6 async_scheduling=false \
30+
prompt="Paris is" hf_access_token=<token> use_chat_template=true
3131
For multislice (e.g. 2x128-device slices), additionally pass:
3232
num_trainer_slices=1 num_samplers_slices=1
3333
@@ -70,6 +70,7 @@
7070
from maxtext.integration.vllm.torchax_converter.base import timer
7171
from maxtext.integration.vllm.torchax_converter.gemma4_moe import Gemma4MaxTextToVLLMConverter
7272
from maxtext.integration.vllm.torchax_converter.qwen3_moe import Qwen3MaxTextToVLLMConverter
73+
from maxtext.integration.vllm.torchax_converter.qwen35_moe import Qwen35MaxTextToVLLMConverter
7374
from maxtext.utils import model_creation_utils
7475

7576
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
@@ -81,6 +82,7 @@
8182
"qwen3-30b-a3b-base": "Qwen/Qwen3-30B-A3B",
8283
"qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B",
8384
"gemma4-26b": "google/gemma-4-26B-A4B",
85+
"qwen3.5-35b-a3b": "Qwen/Qwen3.5-35B-A3B",
8486
# Add more mappings as needed
8587
}
8688

@@ -150,8 +152,6 @@ def _log_weight_stats(converted_state: dict, vllm_state: dict, compare: bool) ->
150152
conv = np.array(weight_array, dtype=np.float32)
151153
# rel_frobenius = ||converted - ref||_F / ||ref||_F.
152154
# ~0 means bit-for-bit correct; ~1 or above means the content is wrong.
153-
# Unlike mean/std/min/max, this catches permutation and transposition bugs
154-
# because it is order-sensitive.
155155
rel_frob = float(np.linalg.norm(conv - ref)) / (float(np.linalg.norm(ref)) + 1e-8)
156156
logging.info(" [VLLM-REF] %s | %s", key, _weight_stats_str(vllm_state[key]))
157157
logging.info(" [DIFF] %s | rel_frobenius=%.6f", key, rel_frob)
@@ -289,6 +289,8 @@ def validate_converter(argv) -> None:
289289

290290
if trainer_config.model_name.startswith("gemma4"):
291291
converter = Gemma4MaxTextToVLLMConverter(trainer_config, mesh)
292+
elif trainer_config.model_name.startswith("qwen3.5"):
293+
converter = Qwen35MaxTextToVLLMConverter(trainer_config, mesh)
292294
else:
293295
converter = Qwen3MaxTextToVLLMConverter(trainer_config, mesh)
294296
with timer("Overall Conversion"):
@@ -317,6 +319,10 @@ def validate_converter(argv) -> None:
317319
"gpu_memory_utilization": getattr(sampler_config, "hbm_utilization_vllm", 0.5),
318320
"async_scheduling": getattr(sampler_config, "async_scheduling", False),
319321
}
322+
# Conditionally add max_num_batched_tokens only for qwen3.5
323+
if trainer_config.model_name == "qwen3.5-35b-a3b":
324+
vllm_kwargs["max_num_batched_tokens"] = 16384
325+
320326
if multislice:
321327
# Pin vLLM to its assigned sampler devices so it doesn't overlap with trainer.
322328
vllm_kwargs["additional_config"] = {

0 commit comments

Comments
 (0)