1717import gc
1818import math
1919import os
20+ import re
2021from dataclasses import fields
2122from functools import partial
2223
8586)
8687
8788
89+ def frozen_param_expect_mtp (model , config ):
90+ def extract_layer_idx (text ):
91+ match = re .search (r"model.layers.(-?\d+\.?\d*)" , text )
92+ if match :
93+ num_str = match .group (1 )
94+ # 区分整数和小数返回(避免123.0这种冗余浮点数)
95+ if "." in num_str :
96+ return float (num_str )
97+ else :
98+ return int (num_str )
99+ return None
100+
101+ # not sure can work on all model
102+ jackpot = set (range (config .num_hidden_layers , config .num_hidden_layers + config .num_nextn_predict_layers ))
103+ for name , param in model .state_dict ().items ():
104+ layer_idx = extract_layer_idx (name )
105+ is_mtp = layer_idx in jackpot
106+ if not is_mtp :
107+ param .stop_gradient = True
108+ else :
109+ param .stop_gradient = False
110+
111+
88112def create_pretrained_dataset (training_args , data_args , model_args ):
89113 assert data_args .input_dir is not None and len (data_args .input_dir .split ()) > 1
90114
@@ -653,6 +677,8 @@ def neft_post_hook(module, input, output):
653677 callbacks += [FP8QuantWeightCallback ()]
654678
655679 print ("callbacks:" , callbacks , flush = True )
680+ # print("ddd: ", model); exit()
681+
656682 trainer = SFTTrainer (
657683 model = model ,
658684 args = training_args ,
@@ -665,6 +691,7 @@ def neft_post_hook(module, input, output):
665691 data_args = data_args ,
666692 callbacks = callbacks ,
667693 )
694+ frozen_param_expect_mtp (model , model_config )
668695 trainable_parameters = [p for p in model .parameters () if not p .stop_gradient ]
669696 trainer .set_optimizer_grouped_parameters (trainable_parameters )
670697
0 commit comments