Skip to content

Commit 27f294d

Browse files
committed
Qwen3.5 vision layers ckpt conversion and decode
1 parent d8763ef commit 27f294d

10 files changed

Lines changed: 266 additions & 52 deletions

File tree

benchmarks/multimodal/multimodal_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def main(config, local_args):
230230
second_per_grids=processor_output.video_second_per_grid, # pytype: disable=attribute-error
231231
spatial_merge_size=config.spatial_merge_size_for_vit, # pytype: disable=attribute-error
232232
position_id_per_seconds=config.position_id_per_seconds,
233+
config=config,
233234
)
234235

235236
if true_length > max_prefill_predict_length:

src/maxtext/checkpoint_conversion/utils/param_mapping.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,64 @@ def QWEN3_5_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fals
10421042
): f"model.language_model.layers.{i}.mlp.experts.gate_up_proj",
10431043
}
10441044
)
1045+
1046+
# Vision mapping for Qwen3.5
1047+
if maxtext_config.use_multimodal and "vision_config" in config:
1048+
vision_config = config["vision_config"]
1049+
n_vision_layers = vision_config["depth"]
1050+
1051+
# Vision patch embedding
1052+
mapping["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-patch_embed-proj-kernel"] = (
1053+
"model.visual.patch_embed.proj.weight"
1054+
)
1055+
mapping["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-patch_embed-proj-bias"] = (
1056+
"model.visual.patch_embed.proj.bias"
1057+
)
1058+
1059+
# Vision positional embedding
1060+
mapping["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-pos_embed_interpolate-pos_embed"] = (
1061+
"model.visual.pos_embed.weight"
1062+
)
1063+
1064+
# Vision blocks
1065+
for i in range(n_vision_layers):
1066+
prefix = f"params-vision_encoder-Qwen3_5MoeVisionEncoder_0-blocks_{i}"
1067+
hf_prefix = f"model.visual.blocks.{i}"
1068+
1069+
# Layer norms
1070+
mapping[f"{prefix}-ln1-scale"] = f"{hf_prefix}.norm1.weight"
1071+
mapping[f"{prefix}-ln1-bias"] = f"{hf_prefix}.norm1.bias"
1072+
mapping[f"{prefix}-ln2-scale"] = f"{hf_prefix}.norm2.weight"
1073+
mapping[f"{prefix}-ln2-bias"] = f"{hf_prefix}.norm2.bias"
1074+
1075+
# Attention
1076+
mapping[f"{prefix}-attn-attn-query-kernel"] = f"{hf_prefix}.attn.qkv.weight"
1077+
mapping[f"{prefix}-attn-attn-query-bias"] = f"{hf_prefix}.attn.qkv.bias"
1078+
mapping[f"{prefix}-attn-attn-key-kernel"] = f"{hf_prefix}.attn.qkv.weight"
1079+
mapping[f"{prefix}-attn-attn-key-bias"] = f"{hf_prefix}.attn.qkv.bias"
1080+
mapping[f"{prefix}-attn-attn-value-kernel"] = f"{hf_prefix}.attn.qkv.weight"
1081+
mapping[f"{prefix}-attn-attn-value-bias"] = f"{hf_prefix}.attn.qkv.bias"
1082+
mapping[f"{prefix}-attn-attn-out-kernel"] = f"{hf_prefix}.attn.proj.weight"
1083+
mapping[f"{prefix}-attn-attn-out-bias"] = f"{hf_prefix}.attn.proj.bias"
1084+
1085+
# MLP
1086+
mapping[f"{prefix}-mlp-kernel"] = f"{hf_prefix}.mlp.linear_fc1.weight"
1087+
mapping[f"{prefix}-mlp-bias"] = f"{hf_prefix}.mlp.linear_fc1.bias"
1088+
mapping[f"{prefix}-mlp_out-kernel"] = f"{hf_prefix}.mlp.linear_fc2.weight"
1089+
mapping[f"{prefix}-mlp_out-bias"] = f"{hf_prefix}.mlp.linear_fc2.bias"
1090+
1091+
# Vision projector (final merger)
1092+
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-ln_q-scale"] = "model.visual.merger.norm.weight"
1093+
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-ln_q-bias"] = "model.visual.merger.norm.bias"
1094+
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-kernel"] = (
1095+
"model.visual.merger.linear_fc1.weight"
1096+
)
1097+
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-bias"] = "model.visual.merger.linear_fc1.bias"
1098+
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-kernel"] = (
1099+
"model.visual.merger.linear_fc2.weight"
1100+
)
1101+
mapping["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-bias"] = "model.visual.merger.linear_fc2.bias"
1102+
10451103
return mapping
10461104

10471105

@@ -1214,6 +1272,92 @@ def concat_ba_and_transpose(input_tensor, target_shape=None):
12141272
hooks[(f"{mlp_prefix}-routed_experts-wi_0", f"{mlp_prefix}-routed_experts-wi_1")] = process_wi_0_wi_1
12151273
hooks[f"{mlp_prefix}-routed_experts-wo"] = transpose_expert
12161274

1275+
# Vision hooks for Qwen3.5
1276+
vision_config = config.get("vision_config", None)
1277+
if vision_config and maxtext_config.use_multimodal:
1278+
n_vision_layers = vision_config["depth"]
1279+
hidden_size = vision_config["hidden_size"]
1280+
1281+
def reshape_kernel_vision(input_tensor, target_shape):
1282+
if saving_to_hf:
1283+
flipped_target_shape = np.flip(np.array(target_shape))
1284+
return input_tensor.reshape(flipped_target_shape).T
1285+
else:
1286+
return input_tensor.T.reshape(target_shape)
1287+
1288+
def reshape_conv3d_patch_embed(input_tensor, target_shape):
1289+
if saving_to_hf:
1290+
return input_tensor.transpose(4, 3, 0, 1, 2)
1291+
else:
1292+
return input_tensor.transpose(2, 3, 4, 1, 0)
1293+
1294+
def split_qkv_query(input_tensor, target_shape):
1295+
if saving_to_hf:
1296+
raise NotImplementedError("Use fusion hook for MaxText->HF")
1297+
else:
1298+
q_weight = input_tensor[:hidden_size, :]
1299+
return q_weight.T.reshape(target_shape)
1300+
1301+
def split_qkv_key(input_tensor, target_shape):
1302+
if saving_to_hf:
1303+
raise NotImplementedError("Use fusion hook for MaxText->HF")
1304+
else:
1305+
k_weight = input_tensor[hidden_size : 2 * hidden_size, :]
1306+
return k_weight.T.reshape(target_shape)
1307+
1308+
def split_qkv_value(input_tensor, target_shape):
1309+
if saving_to_hf:
1310+
raise NotImplementedError("Use fusion hook for MaxText->HF")
1311+
else:
1312+
v_weight = input_tensor[2 * hidden_size :, :]
1313+
return v_weight.T.reshape(target_shape)
1314+
1315+
def split_qkv_bias_query(input_tensor, target_shape):
1316+
if saving_to_hf:
1317+
raise NotImplementedError("Use fusion hook for MaxText->HF")
1318+
else:
1319+
q_bias = input_tensor[:hidden_size]
1320+
return q_bias.reshape(target_shape)
1321+
1322+
def split_qkv_bias_key(input_tensor, target_shape):
1323+
if saving_to_hf:
1324+
raise NotImplementedError("Use fusion hook for MaxText->HF")
1325+
else:
1326+
k_bias = input_tensor[hidden_size : 2 * hidden_size]
1327+
return k_bias.reshape(target_shape)
1328+
1329+
def split_qkv_bias_value(input_tensor, target_shape):
1330+
if saving_to_hf:
1331+
raise NotImplementedError("Use fusion hook for MaxText->HF")
1332+
else:
1333+
v_bias = input_tensor[2 * hidden_size :]
1334+
return v_bias.reshape(target_shape)
1335+
1336+
def reshape_vision_attn_out(input_tensor, target_shape):
1337+
if saving_to_hf:
1338+
return input_tensor.reshape(hidden_size, hidden_size).T
1339+
else:
1340+
return input_tensor.T.reshape(target_shape)
1341+
1342+
# Apply vision hooks
1343+
hooks["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-patch_embed-proj-kernel"] = reshape_conv3d_patch_embed
1344+
1345+
for i in range(n_vision_layers):
1346+
prefix = f"params-vision_encoder-Qwen3_5MoeVisionEncoder_0-blocks_{i}"
1347+
hooks[f"{prefix}-attn-attn-query-kernel"] = split_qkv_query
1348+
hooks[f"{prefix}-attn-attn-query-bias"] = split_qkv_bias_query
1349+
hooks[f"{prefix}-attn-attn-key-kernel"] = split_qkv_key
1350+
hooks[f"{prefix}-attn-attn-key-bias"] = split_qkv_bias_key
1351+
hooks[f"{prefix}-attn-attn-value-kernel"] = split_qkv_value
1352+
hooks[f"{prefix}-attn-attn-value-bias"] = split_qkv_bias_value
1353+
hooks[f"{prefix}-attn-attn-out-kernel"] = reshape_vision_attn_out
1354+
hooks[f"{prefix}-mlp-kernel"] = reshape_kernel_vision
1355+
hooks[f"{prefix}-mlp_out-kernel"] = reshape_kernel_vision
1356+
1357+
# Vision projector
1358+
hooks["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-kernel"] = reshape_kernel_vision
1359+
hooks["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-kernel"] = reshape_kernel_vision
1360+
12171361
return hooks
12181362

12191363

src/maxtext/configs/models/qwen3.5-35b-a3b.yml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,23 @@ partial_rotary_factor: 0.25
4848

4949
# General Model Settings
5050
enable_dropout: False
51+
52+
# Vision Encoder Configuration (need to set use_multimodal=true)
53+
# Based on Qwen3.5 MoE Vision Model Config
54+
image_size_for_vit: 768
55+
hidden_size_for_vit: 1152
56+
intermediate_size_for_vit: 4304
57+
num_attention_heads_for_vit: 16
58+
num_hidden_layers_for_vit: 27
59+
num_channels_for_vit: 3
60+
patch_size_for_vit: 16
61+
temporal_patch_size_for_vit: 2
62+
spatial_merge_size_for_vit: 2
63+
out_hidden_size_for_vit: 2048 # Projects to decoder emb_dim (2048)
64+
num_position_embeddings_for_vit: 2304
65+
deepstack_visual_indexes_for_vit: [] # No deepstack for Qwen3.5 VL
66+
rope_theta_for_vit: 10000
67+
68+
# MRoPE Settings (Multi-dimensional RoPE for multimodal)
69+
use_mrope: true
70+
mrope_section: [11, 11, 10]

src/maxtext/configs/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ class ProfilerType(str, Enum):
262262
"qwen3-next-80b-a3b",
263263
"qwen3-omni-30b-a3b",
264264
"qwen3-custom-30b-a3b",
265-
"qwen3.5-397b-a17b",
266265
"qwen3.5-35b-a3b",
266+
"qwen3.5-397b-a17b",
267267
"gpt3-175b",
268268
"gpt3-22b",
269269
"gpt3-6b",
@@ -2927,6 +2927,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
29272927
"llama4-17b-16e",
29282928
"llama4-17b-128e",
29292929
"qwen3-omni-30b-a3b",
2930+
"qwen3.5-35b-a3b",
29302931
"qwen3.5-397b-a17b",
29312932
)
29322933
if self.model_name not in valid_mm_models and self.model_name != "default":

src/maxtext/inference/decode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def main(argv: Sequence[str]) -> None:
154154
second_per_grids=processor_outputs.video_second_per_grid, # pytype: disable=attribute-error
155155
spatial_merge_size=config.spatial_merge_size_for_vit, # pytype: disable=attribute-error
156156
position_id_per_seconds=config.position_id_per_seconds,
157+
config=config,
157158
)
158159

159160
assert (

src/maxtext/layers/decoders.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,7 @@ def _apply_embedding(
660660
"llama4-17b-16e",
661661
"llama4-17b-128e",
662662
"qwen3-omni-30b-a3b",
663+
"qwen3.5-35b-a3b",
663664
"qwen3.5-397b-a17b",
664665
]:
665666
y = mm_utils.merge_mm_embeddings(
@@ -673,7 +674,7 @@ def _apply_embedding(
673674
raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}")
674675

675676
if video_embeddings is not None and cfg.use_multimodal:
676-
if cfg.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
677+
if cfg.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
677678
y = mm_utils.merge_mm_embeddings(
678679
text_embeddings=y,
679680
multimodal_embeddings=video_embeddings,

src/maxtext/layers/encoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _setup_vision_encoder_layers(self):
7070
self, projector_name, gemma4_vision.Gemma4VisionProjector(config=self.config, mesh=self.mesh, rngs=self.rngs)
7171
)
7272
return encoder_name, projector_name
73-
elif self.config.model_name in ["qwen3.5-397b-a17b"]:
73+
elif self.config.model_name in ["qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
7474
from maxtext.models import qwen3_5_vision # pylint: disable=import-outside-toplevel
7575

7676
encoder_name = "Qwen3_5MoeVisionEncoder_0"

src/maxtext/models/qwen3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,8 @@ def __init__(
734734
use_qk_norm=cfg.use_qk_norm,
735735
query_pre_attn_scalar=scaling_factor,
736736
model_mode=model_mode,
737+
use_mrope=cfg.use_mrope,
738+
mrope_section=cfg.mrope_section,
737739
rngs=rngs,
738740
)
739741

src/maxtext/multimodal/processor.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def preprocess_mm_data(config):
4444

4545
images = [mm_utils.load_image_from_path(p) for p in config.image_path.split(",")]
4646
processor_outputs = preprocess_mm_data_llama4(images)
47-
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
47+
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
4848
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni # pylint: disable=import-outside-toplevel
4949

5050
processor_outputs = preprocess_mm_data_qwen3_omni(config)
@@ -68,7 +68,7 @@ def preprocess_image_for_training(image, model_name):
6868
from maxtext.multimodal.processor_llama4 import preprocess_mm_data_llama4 # pylint: disable=import-outside-toplevel
6969

7070
return preprocess_mm_data_llama4(image)
71-
elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
71+
elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
7272
from maxtext.multimodal.processor_qwen3_omni import preprocess_mm_data_qwen3_omni_for_training # pylint: disable=import-outside-toplevel
7373

7474
return preprocess_mm_data_qwen3_omni_for_training(image)
@@ -90,7 +90,7 @@ def get_image_offsets(config, processor_output: mm_utils.PreprocessorOutput | No
9090
from maxtext.multimodal.processor_llama4 import get_image_offsets_llama4 # pylint: disable=import-outside-toplevel
9191

9292
return get_image_offsets_llama4(processor_output)
93-
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
93+
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
9494
from maxtext.multimodal.processor_qwen3_omni import get_mm_offsets_qwen3_omni # pylint: disable=import-outside-toplevel
9595

9696
return get_mm_offsets_qwen3_omni(config, processor_output)
@@ -112,7 +112,7 @@ def reformat_prompt(prompt, image_placeholder, model_name, num_images, video_pla
112112
from maxtext.multimodal.processor_llama4 import reformat_prompt_llama4 # pylint: disable=import-outside-toplevel
113113

114114
return reformat_prompt_llama4(prompt, image_placeholder, num_images)
115-
elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
115+
elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
116116
from maxtext.multimodal.processor_qwen3_omni import reformat_prompt_qwen3_omni # pylint: disable=import-outside-toplevel
117117

118118
return reformat_prompt_qwen3_omni(
@@ -137,7 +137,7 @@ def reformat_response(response, model_name):
137137
elif model_name in ["gemma4-26b", "gemma4-31b", "gemma4-e2b", "gemma4-e4b"]:
138138
formatted_response = f"{response}<end_of_turn>"
139139
return formatted_response
140-
elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
140+
elif model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
141141
formatted_response = f"{response}<|im_end|>"
142142
return formatted_response
143143
else:
@@ -158,7 +158,7 @@ def prepare_text_for_image_fusion(tokens, config, processor_output=None):
158158
from maxtext.multimodal.processor_llama4 import add_extra_tokens_for_images_llama4 # pylint: disable=import-outside-toplevel
159159

160160
return add_extra_tokens_for_images_llama4(tokens, processor_output)
161-
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
161+
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
162162
from maxtext.multimodal.processor_qwen3_omni import add_extra_tokens_for_qwen3_omni # pylint: disable=import-outside-toplevel
163163

164164
return add_extra_tokens_for_qwen3_omni(tokens, config, processor_output)
@@ -181,7 +181,7 @@ def get_dummy_image_shape_for_init(model_name, batch_size=1, num_image_per_seque
181181
from maxtext.multimodal.processor_llama4 import get_dummy_image_shape_for_init_llama4 # pylint: disable=import-outside-toplevel
182182

183183
image_shape = get_dummy_image_shape_for_init_llama4(batch_size, num_image_per_sequence)
184-
elif model_name.startswith("qwen3-omni-30b-a3b") or model_name.startswith("qwen3.5-397b-a17b"):
184+
elif model_name.startswith("qwen3-omni-30b-a3b") or model_name.startswith("qwen3.5"):
185185
from maxtext.multimodal.processor_qwen3_omni import get_dummy_image_shape_for_init_qwen3_omni # pylint: disable=import-outside-toplevel
186186

187187
image_shape = get_dummy_image_shape_for_init_qwen3_omni(batch_size)
@@ -222,22 +222,26 @@ def get_bidirectional_mask_vision(config, decoder_input_tokens, is_video: bool =
222222
from maxtext.multimodal.processor_llama4 import LLAMA4_PATCH_TOKEN # pylint: disable=import-outside-toplevel
223223

224224
bidirectional_mask_vision = decoder_input_tokens == LLAMA4_PATCH_TOKEN
225-
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-397b-a17b"]:
226-
from maxtext.multimodal.processor_qwen3_omni import QWEN3_OMNI_IMAGE_TOKEN, QWEN3_OMNI_VIDEO_TOKEN # pylint: disable=import-outside-toplevel
225+
elif config.model_name in ["qwen3-omni-30b-a3b", "qwen3.5-35b-a3b", "qwen3.5-397b-a17b"]:
226+
from maxtext.multimodal.processor_qwen3_omni import QwenTokens # pylint: disable=import-outside-toplevel
227+
228+
tokens = QwenTokens(config)
227229

228230
if is_video:
229-
bidirectional_mask_vision = decoder_input_tokens == QWEN3_OMNI_VIDEO_TOKEN
231+
bidirectional_mask_vision = decoder_input_tokens == tokens.video_pad
230232
else:
231-
bidirectional_mask_vision = decoder_input_tokens == QWEN3_OMNI_IMAGE_TOKEN
233+
bidirectional_mask_vision = decoder_input_tokens == tokens.image_pad
232234
return bidirectional_mask_vision
233235

234236

235237
def get_bidirectional_mask_audio(config, decoder_input_tokens):
236238
"""Get the bidirectional mask for specific models."""
237239
bidirectional_mask_audio = None
238240
if config.model_name in ["qwen3-omni-30b-a3b"]:
239-
from maxtext.multimodal.processor_qwen3_omni import QWEN3_OMNI_AUDIO_TOKEN # pylint: disable=import-outside-toplevel
241+
from maxtext.multimodal.processor_qwen3_omni import QwenTokens # pylint: disable=import-outside-toplevel
242+
243+
tokens = QwenTokens(config)
240244

241245
# Create bidirectional_mask for audio token merging
242-
bidirectional_mask_audio = decoder_input_tokens == QWEN3_OMNI_AUDIO_TOKEN
246+
bidirectional_mask_audio = decoder_input_tokens == tokens.audio_pad
243247
return bidirectional_mask_audio

0 commit comments

Comments
 (0)