@@ -99,7 +99,7 @@ def extract_layer_idx(text):
9999 return None
100100
101101 # not sure whether this works for all models
102- jackpot = set (range (config .num_hidden_layers , config .num_hidden_layers + config .num_nextn_predict_layers ))
102+ jackpot = set (range (config .num_hidden_layers , config .num_hidden_layers + config .mtp_num_layers ))
103103 for name , param in model .state_dict ().items ():
104104 layer_idx = extract_layer_idx (name )
105105 is_mtp = layer_idx in jackpot
@@ -288,6 +288,22 @@ def run_sft(
288288 LlmMetaConfig .set_llm_config (model_config , training_args )
289289 model_config .use_fast_layer_norm = model_args .use_fast_layer_norm
290290
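291+ # Autoregressive MTP training: when mtp_num_layers > 1, swap mtp_num_layers and num_nextn_predict_layers on both model_config and training_args.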
292+ activate_autoregressive_mtp_training = False
293+ if model_config .mtp_num_layers > 1 :
294+ activate_autoregressive_mtp_training = True
295+ tmp = model_config .mtp_num_layers
296+ model_config .mtp_num_layers = model_config .num_nextn_predict_layers
297+ model_config .num_nextn_predict_layers = tmp
298+
299+ tmp = training_args .mtp_num_layers
300+ training_args .mtp_num_layers = training_args .num_nextn_predict_layers
301+ training_args .num_nextn_predict_layers = tmp
302+
303+ logger .info (
304+ f"MTP args changing for autoregressive mtp training, mtp_num_layers: { model_config .mtp_num_layers } , num_nextn_predict_layers: { model_config .num_nextn_predict_layers } !!"
305+ )
306+
291307 # Config for model using dropout, such as GPT.
292308 if hasattr (model_config , "hidden_dropout_prob" ):
293309 model_config .hidden_dropout_prob = finetuning_args .hidden_dropout_prob
@@ -699,10 +715,14 @@ def fetch_and_serialize(generator, dtype):
699715 data_args = data_args ,
700716 callbacks = callbacks ,
701717 )
718+ freeze_param_except_mtp (model , model_config )
719+
720+ if activate_autoregressive_mtp_training :
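721+ # Autoregressive MTP training is active. NOTE(review): freeze_param_except_mtp appears to be called already just above this `if`; confirm the repeated call is intended.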
722+ freeze_param_except_mtp (model , model_config )
702723 trainable_parameters = [
703724 p for p in model .parameters () if not p .stop_gradient or ("quantization_linear" in p .name and "w_1" in p .name )
704725 ]
705- freeze_param_except_mtp (model , model_config )
706726 trainer .set_optimizer_grouped_parameters (trainable_parameters )
707727
708728 # Train
0 commit comments