@@ -1042,6 +1042,68 @@ def QWEN3_5_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fals
10421042 ): f"model.language_model.layers.{ i } .mlp.experts.gate_up_proj" ,
10431043 }
10441044 )
1045+
1046+ # Vision mapping for Qwen3.5
1047+ if maxtext_config .use_multimodal and "vision_config" in config :
1048+ vision_config = config ["vision_config" ]
1049+ n_vision_layers = vision_config ["depth" ]
1050+
1051+ # Vision patch embedding
1052+ mapping ["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-patch_embed-proj-kernel" ] = (
1053+ "model.visual.patch_embed.proj.weight"
1054+ )
1055+ mapping ["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-patch_embed-proj-bias" ] = (
1056+ "model.visual.patch_embed.proj.bias"
1057+ )
1058+
1059+ # Vision positional embedding
1060+ mapping ["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-pos_embed_interpolate-pos_embed" ] = (
1061+ "model.visual.pos_embed.weight"
1062+ )
1063+
1064+ # Vision blocks
1065+ for i in range (n_vision_layers ):
1066+ prefix = f"params-vision_encoder-Qwen3_5MoeVisionEncoder_0-blocks_{ i } "
1067+ hf_prefix = f"model.visual.blocks.{ i } "
1068+
1069+ # Layer norms
1070+ mapping [f"{ prefix } -ln1-scale" ] = f"{ hf_prefix } .norm1.weight"
1071+ mapping [f"{ prefix } -ln1-bias" ] = f"{ hf_prefix } .norm1.bias"
1072+ mapping [f"{ prefix } -ln2-scale" ] = f"{ hf_prefix } .norm2.weight"
1073+ mapping [f"{ prefix } -ln2-bias" ] = f"{ hf_prefix } .norm2.bias"
1074+
1075+ # Attention
1076+ mapping [f"{ prefix } -attn-attn-query-kernel" ] = f"{ hf_prefix } .attn.qkv.weight"
1077+ mapping [f"{ prefix } -attn-attn-query-bias" ] = f"{ hf_prefix } .attn.qkv.bias"
1078+ mapping [f"{ prefix } -attn-attn-key-kernel" ] = f"{ hf_prefix } .attn.qkv.weight"
1079+ mapping [f"{ prefix } -attn-attn-key-bias" ] = f"{ hf_prefix } .attn.qkv.bias"
1080+ mapping [f"{ prefix } -attn-attn-value-kernel" ] = f"{ hf_prefix } .attn.qkv.weight"
1081+ mapping [f"{ prefix } -attn-attn-value-bias" ] = f"{ hf_prefix } .attn.qkv.bias"
1082+ mapping [f"{ prefix } -attn-attn-out-kernel" ] = f"{ hf_prefix } .attn.proj.weight"
1083+ mapping [f"{ prefix } -attn-attn-out-bias" ] = f"{ hf_prefix } .attn.proj.bias"
1084+
1085+ # MLP
1086+ mapping [f"{ prefix } -mlp-kernel" ] = f"{ hf_prefix } .mlp.linear_fc1.weight"
1087+ mapping [f"{ prefix } -mlp-bias" ] = f"{ hf_prefix } .mlp.linear_fc1.bias"
1088+ mapping [f"{ prefix } -mlp_out-kernel" ] = f"{ hf_prefix } .mlp.linear_fc2.weight"
1089+ mapping [f"{ prefix } -mlp_out-bias" ] = f"{ hf_prefix } .mlp.linear_fc2.bias"
1090+
1091+ # Vision projector (final merger)
1092+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-ln_q-scale" ] = "model.visual.merger.norm.weight"
1093+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-ln_q-bias" ] = "model.visual.merger.norm.bias"
1094+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-kernel" ] = (
1095+ "model.visual.merger.linear_fc1.weight"
1096+ )
1097+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-bias" ] = (
1098+ "model.visual.merger.linear_fc1.bias"
1099+ )
1100+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-kernel" ] = (
1101+ "model.visual.merger.linear_fc2.weight"
1102+ )
1103+ mapping ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-bias" ] = (
1104+ "model.visual.merger.linear_fc2.bias"
1105+ )
1106+
10451107 return mapping
10461108
10471109
@@ -1214,6 +1276,92 @@ def concat_ba_and_transpose(input_tensor, target_shape=None):
12141276 hooks [(f"{ mlp_prefix } -routed_experts-wi_0" , f"{ mlp_prefix } -routed_experts-wi_1" )] = process_wi_0_wi_1
12151277 hooks [f"{ mlp_prefix } -routed_experts-wo" ] = transpose_expert
12161278
1279+ # Vision hooks for Qwen3.5
1280+ vision_config = config .get ("vision_config" , None )
1281+ if vision_config and maxtext_config .use_multimodal :
1282+ n_vision_layers = vision_config ["depth" ]
1283+ hidden_size = vision_config ["hidden_size" ]
1284+
1285+ def reshape_kernel_vision (input_tensor , target_shape ):
1286+ if saving_to_hf :
1287+ flipped_target_shape = np .flip (np .array (target_shape ))
1288+ return input_tensor .reshape (flipped_target_shape ).T
1289+ else :
1290+ return input_tensor .T .reshape (target_shape )
1291+
1292+ def reshape_conv3d_patch_embed (input_tensor , target_shape ):
1293+ if saving_to_hf :
1294+ return input_tensor .transpose (4 , 3 , 0 , 1 , 2 )
1295+ else :
1296+ return input_tensor .transpose (2 , 3 , 4 , 1 , 0 )
1297+
1298+ def split_qkv_query (input_tensor , target_shape ):
1299+ if saving_to_hf :
1300+ raise NotImplementedError ("Use fusion hook for MaxText->HF" )
1301+ else :
1302+ q_weight = input_tensor [:hidden_size , :]
1303+ return q_weight .T .reshape (target_shape )
1304+
1305+ def split_qkv_key (input_tensor , target_shape ):
1306+ if saving_to_hf :
1307+ raise NotImplementedError ("Use fusion hook for MaxText->HF" )
1308+ else :
1309+ k_weight = input_tensor [hidden_size : 2 * hidden_size , :]
1310+ return k_weight .T .reshape (target_shape )
1311+
1312+ def split_qkv_value (input_tensor , target_shape ):
1313+ if saving_to_hf :
1314+ raise NotImplementedError ("Use fusion hook for MaxText->HF" )
1315+ else :
1316+ v_weight = input_tensor [2 * hidden_size :, :]
1317+ return v_weight .T .reshape (target_shape )
1318+
1319+ def split_qkv_bias_query (input_tensor , target_shape ):
1320+ if saving_to_hf :
1321+ raise NotImplementedError ("Use fusion hook for MaxText->HF" )
1322+ else :
1323+ q_bias = input_tensor [:hidden_size ]
1324+ return q_bias .reshape (target_shape )
1325+
1326+ def split_qkv_bias_key (input_tensor , target_shape ):
1327+ if saving_to_hf :
1328+ raise NotImplementedError ("Use fusion hook for MaxText->HF" )
1329+ else :
1330+ k_bias = input_tensor [hidden_size : 2 * hidden_size ]
1331+ return k_bias .reshape (target_shape )
1332+
1333+ def split_qkv_bias_value (input_tensor , target_shape ):
1334+ if saving_to_hf :
1335+ raise NotImplementedError ("Use fusion hook for MaxText->HF" )
1336+ else :
1337+ v_bias = input_tensor [2 * hidden_size :]
1338+ return v_bias .reshape (target_shape )
1339+
1340+ def reshape_vision_attn_out (input_tensor , target_shape ):
1341+ if saving_to_hf :
1342+ return input_tensor .reshape (hidden_size , hidden_size ).T
1343+ else :
1344+ return input_tensor .T .reshape (target_shape )
1345+
1346+ # Apply vision hooks
1347+ hooks ["params-vision_encoder-Qwen3_5MoeVisionEncoder_0-patch_embed-proj-kernel" ] = reshape_conv3d_patch_embed
1348+
1349+ for i in range (n_vision_layers ):
1350+ prefix = f"params-vision_encoder-Qwen3_5MoeVisionEncoder_0-blocks_{ i } "
1351+ hooks [f"{ prefix } -attn-attn-query-kernel" ] = split_qkv_query
1352+ hooks [f"{ prefix } -attn-attn-query-bias" ] = split_qkv_bias_query
1353+ hooks [f"{ prefix } -attn-attn-key-kernel" ] = split_qkv_key
1354+ hooks [f"{ prefix } -attn-attn-key-bias" ] = split_qkv_bias_key
1355+ hooks [f"{ prefix } -attn-attn-value-kernel" ] = split_qkv_value
1356+ hooks [f"{ prefix } -attn-attn-value-bias" ] = split_qkv_bias_value
1357+ hooks [f"{ prefix } -attn-attn-out-kernel" ] = reshape_vision_attn_out
1358+ hooks [f"{ prefix } -mlp-kernel" ] = reshape_kernel_vision
1359+ hooks [f"{ prefix } -mlp_out-kernel" ] = reshape_kernel_vision
1360+
1361+ # Vision projector
1362+ hooks ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_0-kernel" ] = reshape_kernel_vision
1363+ hooks ["params-vision_encoder-Qwen3_5MoeVisionProjector_0-merger-mlp_2-kernel" ] = reshape_kernel_vision
1364+
12171365 return hooks
12181366
12191367
0 commit comments