@@ -1042,6 +1042,64 @@ def QWEN3_5_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fals
10421042 ): f"model.language_model.layers.{ i } .mlp.experts.gate_up_proj" ,
10431043 }
10441044 )
1045+
1046+ # Vision mapping for Qwen3.5
1047+ if maxtext_config .use_multimodal and "vision_config" in config :
1048+ vision_config = config ["vision_config" ]
1049+ n_vision_layers = vision_config ["depth" ]
1050+
1051+ # Vision patch embedding
1052+ mapping ["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-patch_embed-proj-kernel" ] = (
1053+ "model.visual.patch_embed.proj.weight"
1054+ )
1055+ mapping ["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-patch_embed-proj-bias" ] = (
1056+ "model.visual.patch_embed.proj.bias"
1057+ )
1058+
1059+ # Vision positional embedding
1060+ mapping ["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-pos_embed_interpolate-pos_embed" ] = (
1061+ "model.visual.pos_embed.weight"
1062+ )
1063+
1064+ # Vision blocks
1065+ for i in range (n_vision_layers ):
1066+ prefix = f"params-vision_encoder-Qwen3_5MoeVisionEncoder_0-blocks_{ i } "
1067+ hf_prefix = f"model.visual.blocks.{ i } "
1068+
1069+ # Layer norms
1070+ mapping [f"{ prefix } -ln1-scale" ] = f"{ hf_prefix } .norm1.weight"
1071+ mapping [f"{ prefix } -ln1-bias" ] = f"{ hf_prefix } .norm1.bias"
1072+ mapping [f"{ prefix } -ln2-scale" ] = f"{ hf_prefix } .norm2.weight"
1073+ mapping [f"{ prefix } -ln2-bias" ] = f"{ hf_prefix } .norm2.bias"
1074+
1075+ # Attention
1076+ mapping [f"{ prefix } -attn-attn-query-kernel" ] = f"{ hf_prefix } .attn.qkv.weight"
1077+ mapping [f"{ prefix } -attn-attn-query-bias" ] = f"{ hf_prefix } .attn.qkv.bias"
1078+ mapping [f"{ prefix } -attn-attn-key-kernel" ] = f"{ hf_prefix } .attn.qkv.weight"
1079+ mapping [f"{ prefix } -attn-attn-key-bias" ] = f"{ hf_prefix } .attn.qkv.bias"
1080+ mapping [f"{ prefix } -attn-attn-value-kernel" ] = f"{ hf_prefix } .attn.qkv.weight"
1081+ mapping [f"{ prefix } -attn-attn-value-bias" ] = f"{ hf_prefix } .attn.qkv.bias"
1082+ mapping [f"{ prefix } -attn-attn-out-kernel" ] = f"{ hf_prefix } .attn.proj.weight"
1083+ mapping [f"{ prefix } -attn-attn-out-bias" ] = f"{ hf_prefix } .attn.proj.bias"
1084+
1085+ # MLP
1086+ mapping [f"{ prefix } -mlp-kernel" ] = f"{ hf_prefix } .mlp.linear_fc1.weight"
1087+ mapping [f"{ prefix } -mlp-bias" ] = f"{ hf_prefix } .mlp.linear_fc1.bias"
1088+ mapping [f"{ prefix } -mlp_out-kernel" ] = f"{ hf_prefix } .mlp.linear_fc2.weight"
1089+ mapping [f"{ prefix } -mlp_out-bias" ] = f"{ hf_prefix } .mlp.linear_fc2.bias"
1090+
1091+ # Vision projector (final merger)
1092+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-ln_q-scale" ] = "model.visual.merger.norm.weight"
1093+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-ln_q-bias" ] = "model.visual.merger.norm.bias"
1094+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-kernel" ] = (
1095+ "model.visual.merger.linear_fc1.weight"
1096+ )
1097+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-bias" ] = "model.visual.merger.linear_fc1.bias"
1098+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-kernel" ] = (
1099+ "model.visual.merger.linear_fc2.weight"
1100+ )
1101+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-bias" ] = "model.visual.merger.linear_fc2.bias"
1102+
10451103 return mapping
10461104
10471105
@@ -1214,6 +1272,92 @@ def concat_ba_and_transpose(input_tensor, target_shape=None):
12141272 hooks [(f"{ mlp_prefix } -routed_experts-wi_0" , f"{ mlp_prefix } -routed_experts-wi_1" )] = process_wi_0_wi_1
12151273 hooks [f"{ mlp_prefix } -routed_experts-wo" ] = transpose_expert
12161274
1275+ # Vision hooks for Qwen3.5
1276+ vision_config = config .get ("vision_config" , None )
1277+ if vision_config and maxtext_config .use_multimodal :
1278+ n_vision_layers = vision_config ["depth" ]
1279+ hidden_size = vision_config ["hidden_size" ]
1280+
def reshape_kernel_vision(input_tensor, target_shape):
  """Convert a vision kernel between the two checkpoint layouts.

  The direction is selected by the enclosing scope's `saving_to_hf` flag.

  Args:
    input_tensor: Array-like exposing `.reshape` and `.T`.
    target_shape: Shape requested for the converted tensor.

  Returns:
    When `saving_to_hf` is true: the input reshaped to the reversed
    `target_shape` and then transposed. Otherwise: the transposed input
    reshaped to `target_shape`.
  """
  if not saving_to_hf:
    return input_tensor.T.reshape(target_shape)
  # Reshape to the mirrored shape first so the final transpose restores the
  # axis order given by `target_shape`.
  mirrored_shape = np.flip(np.array(target_shape))
  return input_tensor.reshape(mirrored_shape).T
1287+
def reshape_conv3d_patch_embed(input_tensor, target_shape):
  """Permute the five axes of the patch-embedding kernel.

  The two permutations are inverses of each other, so converting one way and
  then back restores the original axis order.

  Args:
    input_tensor: Rank-5 array-like exposing `.transpose`.
      NOTE(review): presumably the HF side is (out, in, T, H, W) and the
      MaxText side is (T, H, W, in, out) — confirm against both models.
    target_shape: Unused; accepted so the hook signature matches the others.

  Returns:
    The input with its axes reordered for the direction selected by the
    enclosing scope's `saving_to_hf` flag.
  """
  axis_order = (4, 3, 0, 1, 2) if saving_to_hf else (2, 3, 4, 1, 0)
  return input_tensor.transpose(*axis_order)
1293+
def split_qkv_query(input_tensor, target_shape):
  """Extract the query kernel from a fused QKV weight.

  Takes the first `hidden_size` rows of the 2-D input, transposes them and
  reshapes to `target_shape`. `hidden_size` and `saving_to_hf` come from the
  enclosing scope.
  NOTE(review): assumes the fused weight stacks query, key, value along
  axis 0 in that order — confirm against the HF checkpoint layout.

  Args:
    input_tensor: Fused 2-D QKV weight.
    target_shape: Shape of the destination query kernel.

  Returns:
    The query kernel reshaped to `target_shape`.

  Raises:
    NotImplementedError: If `saving_to_hf` is true; this hook only splits.
  """
  if saving_to_hf:
    raise NotImplementedError("Use fusion hook for MaxText->HF")
  return input_tensor[:hidden_size, :].T.reshape(target_shape)
1300+
def split_qkv_key(input_tensor, target_shape):
  """Extract the key kernel from a fused QKV weight.

  Takes rows `hidden_size` up to (not including) `2 * hidden_size` of the 2-D
  input, transposes them and reshapes to `target_shape`. `hidden_size` and
  `saving_to_hf` come from the enclosing scope.
  NOTE(review): assumes the fused weight stacks query, key, value along
  axis 0 in that order — confirm against the HF checkpoint layout.

  Args:
    input_tensor: Fused 2-D QKV weight.
    target_shape: Shape of the destination key kernel.

  Returns:
    The key kernel reshaped to `target_shape`.

  Raises:
    NotImplementedError: If `saving_to_hf` is true; this hook only splits.
  """
  if saving_to_hf:
    raise NotImplementedError("Use fusion hook for MaxText->HF")
  row_start, row_stop = hidden_size, 2 * hidden_size
  return input_tensor[row_start:row_stop, :].T.reshape(target_shape)
1307+
def split_qkv_value(input_tensor, target_shape):
  """Extract the value kernel from a fused QKV weight.

  Takes every row from `2 * hidden_size` onward of the 2-D input, transposes
  them and reshapes to `target_shape`. `hidden_size` and `saving_to_hf` come
  from the enclosing scope.
  NOTE(review): assumes the fused weight stacks query, key, value along
  axis 0 in that order — confirm against the HF checkpoint layout.

  Args:
    input_tensor: Fused 2-D QKV weight.
    target_shape: Shape of the destination value kernel.

  Returns:
    The value kernel reshaped to `target_shape`.

  Raises:
    NotImplementedError: If `saving_to_hf` is true; this hook only splits.
  """
  if saving_to_hf:
    raise NotImplementedError("Use fusion hook for MaxText->HF")
  value_offset = 2 * hidden_size
  return input_tensor[value_offset:, :].T.reshape(target_shape)
1314+
def split_qkv_bias_query(input_tensor, target_shape):
  """Extract the query bias from a fused QKV bias.

  Takes the first `hidden_size` entries of the input and reshapes them to
  `target_shape`. `hidden_size` and `saving_to_hf` come from the enclosing
  scope.
  NOTE(review): assumes the fused bias is ordered query, key, value —
  confirm against the HF checkpoint layout.

  Args:
    input_tensor: Fused QKV bias.
    target_shape: Shape of the destination query bias.

  Returns:
    The query bias reshaped to `target_shape`.

  Raises:
    NotImplementedError: If `saving_to_hf` is true; this hook only splits.
  """
  if saving_to_hf:
    raise NotImplementedError("Use fusion hook for MaxText->HF")
  return input_tensor[:hidden_size].reshape(target_shape)
1321+
def split_qkv_bias_key(input_tensor, target_shape):
  """Extract the key bias from a fused QKV bias.

  Takes entries `hidden_size` up to (not including) `2 * hidden_size` of the
  input and reshapes them to `target_shape`. `hidden_size` and `saving_to_hf`
  come from the enclosing scope.
  NOTE(review): assumes the fused bias is ordered query, key, value —
  confirm against the HF checkpoint layout.

  Args:
    input_tensor: Fused QKV bias.
    target_shape: Shape of the destination key bias.

  Returns:
    The key bias reshaped to `target_shape`.

  Raises:
    NotImplementedError: If `saving_to_hf` is true; this hook only splits.
  """
  if saving_to_hf:
    raise NotImplementedError("Use fusion hook for MaxText->HF")
  first, past_last = hidden_size, 2 * hidden_size
  return input_tensor[first:past_last].reshape(target_shape)
1328+
def split_qkv_bias_value(input_tensor, target_shape):
  """Extract the value bias from a fused QKV bias.

  Takes every entry from `2 * hidden_size` onward of the input and reshapes
  them to `target_shape`. `hidden_size` and `saving_to_hf` come from the
  enclosing scope.
  NOTE(review): assumes the fused bias is ordered query, key, value —
  confirm against the HF checkpoint layout.

  Args:
    input_tensor: Fused QKV bias.
    target_shape: Shape of the destination value bias.

  Returns:
    The value bias reshaped to `target_shape`.

  Raises:
    NotImplementedError: If `saving_to_hf` is true; this hook only splits.
  """
  if saving_to_hf:
    raise NotImplementedError("Use fusion hook for MaxText->HF")
  value_offset = 2 * hidden_size
  return input_tensor[value_offset:].reshape(target_shape)
1335+
def reshape_vision_attn_out(input_tensor, target_shape):
  """Convert the attention output-projection kernel between layouts.

  `hidden_size` and `saving_to_hf` come from the enclosing scope.

  Args:
    input_tensor: Array-like exposing `.reshape` and `.T`.
    target_shape: Shape requested for the converted tensor; only used when
      `saving_to_hf` is false.

  Returns:
    When `saving_to_hf` is true: the input collapsed to
    `(hidden_size, hidden_size)` and transposed. Otherwise: the transposed
    input reshaped to `target_shape`.
  """
  if not saving_to_hf:
    return input_tensor.T.reshape(target_shape)
  # Collapse to a square 2-D matrix before transposing.
  square = input_tensor.reshape(hidden_size, hidden_size)
  return square.T
1341+
1342+ # Apply vision hooks
1343+ hooks ["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-patch_embed-proj-kernel" ] = reshape_conv3d_patch_embed
1344+
1345+ for i in range (n_vision_layers ):
1346+ prefix = f"params-vision_encoder-Qwen3_5MoeVisionEncoder_0-blocks_{ i } "
1347+ hooks [f"{ prefix } -attn-attn-query-kernel" ] = split_qkv_query
1348+ hooks [f"{ prefix } -attn-attn-query-bias" ] = split_qkv_bias_query
1349+ hooks [f"{ prefix } -attn-attn-key-kernel" ] = split_qkv_key
1350+ hooks [f"{ prefix } -attn-attn-key-bias" ] = split_qkv_bias_key
1351+ hooks [f"{ prefix } -attn-attn-value-kernel" ] = split_qkv_value
1352+ hooks [f"{ prefix } -attn-attn-value-bias" ] = split_qkv_bias_value
1353+ hooks [f"{ prefix } -attn-attn-out-kernel" ] = reshape_vision_attn_out
1354+ hooks [f"{ prefix } -mlp-kernel" ] = reshape_kernel_vision
1355+ hooks [f"{ prefix } -mlp_out-kernel" ] = reshape_kernel_vision
1356+
1357+ # Vision projector
1358+ hooks ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-kernel" ] = reshape_kernel_vision
1359+ hooks ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-kernel" ] = reshape_kernel_vision
1360+
12171361 return hooks
12181362
12191363
0 commit comments