@@ -304,20 +304,6 @@ def run_sft(
304304 LlmMetaConfig .set_llm_config (model_config , training_args )
305305 model_config .use_fast_layer_norm = model_args .use_fast_layer_norm
306306
307- # autoregressive mtp training
308- if model_config .mtp_num_layers > 1 :
309- tmp = model_config .mtp_num_layers
310- model_config .mtp_num_layers = model_config .num_nextn_predict_layers
311- model_config .num_nextn_predict_layers = tmp
312-
313- tmp = training_args .mtp_num_layers
314- training_args .mtp_num_layers = training_args .num_nextn_predict_layers
315- training_args .num_nextn_predict_layers = tmp
316-
317- logger .info (
318- f"MTP args changing for autoregressive mtp training, mtp_num_layers: { model_config .mtp_num_layers } , num_nextn_predict_layers: { model_config .num_nextn_predict_layers } !!"
319- )
320-
321307 # Config for model using dropout, such as GPT.
322308 if hasattr (model_config , "hidden_dropout_prob" ):
323309 model_config .hidden_dropout_prob = finetuning_args .hidden_dropout_prob
@@ -407,6 +393,20 @@ def neft_post_hook(module, input, output):
407393 else :
408394 raise NotImplementedError ("Only support neftune for model with get_input_embeddings" )
409395
396+ # autoregressive mtp training
397+ if model_config .mtp_num_layers > 1 :
398+ tmp = model_config .mtp_num_layers
399+ model_config .mtp_num_layers = model_config .num_nextn_predict_layers
400+ model_config .num_nextn_predict_layers = tmp
401+
402+ tmp = training_args .mtp_num_layers
403+ training_args .mtp_num_layers = training_args .num_nextn_predict_layers
404+ training_args .num_nextn_predict_layers = tmp
405+
406+ logger .info (
407+ f"MTP args changing for autoregressive mtp training, mtp_num_layers: { model_config .mtp_num_layers } , num_nextn_predict_layers: { model_config .num_nextn_predict_layers } !!"
408+ )
409+
410410 runtime_timer = RuntimeTimer ("Creating SFT MapDataset" )
411411
412412 # Load tokenizer & processor & dataset
@@ -723,7 +723,6 @@ def fetch_and_serialize(generator, dtype):
723723 callbacks += [FP8QuantWeightCallback ()]
724724
725725 print ("callbacks:" , callbacks , flush = True )
726- # print("ddd: ", model); exit()
727726
728727 trainer = SFTTrainer (
729728 model = model ,
0 commit comments