@@ -290,20 +290,6 @@ def run_sft(
290290 LlmMetaConfig .set_llm_config (model_config , training_args )
291291 model_config .use_fast_layer_norm = model_args .use_fast_layer_norm
292292
293- # autoregressive mtp training
294- if model_config .mtp_num_layers > 1 :
295- tmp = model_config .mtp_num_layers
296- model_config .mtp_num_layers = model_config .num_nextn_predict_layers
297- model_config .num_nextn_predict_layers = tmp
298-
299- tmp = training_args .mtp_num_layers
300- training_args .mtp_num_layers = training_args .num_nextn_predict_layers
301- training_args .num_nextn_predict_layers = tmp
302-
303- logger .info (
304- f"MTP args changing for autoregressive mtp training, mtp_num_layers: { model_config .mtp_num_layers } , num_nextn_predict_layers: { model_config .num_nextn_predict_layers } !!"
305- )
306-
307293 # Config for model using dropout, such as GPT.
308294 if hasattr (model_config , "hidden_dropout_prob" ):
309295 model_config .hidden_dropout_prob = finetuning_args .hidden_dropout_prob
@@ -390,6 +376,20 @@ def neft_post_hook(module, input, output):
390376 else :
391377 raise NotImplementedError ("Only support neftune for model with get_input_embeddings" )
392378
379+ # autoregressive mtp training
380+ if model_config .mtp_num_layers > 1 :
381+ tmp = model_config .mtp_num_layers
382+ model_config .mtp_num_layers = model_config .num_nextn_predict_layers
383+ model_config .num_nextn_predict_layers = tmp
384+
385+ tmp = training_args .mtp_num_layers
386+ training_args .mtp_num_layers = training_args .num_nextn_predict_layers
387+ training_args .num_nextn_predict_layers = tmp
388+
389+ logger .info (
390+ f"MTP args changing for autoregressive mtp training, mtp_num_layers: { model_config .mtp_num_layers } , num_nextn_predict_layers: { model_config .num_nextn_predict_layers } !!"
391+ )
392+
393393 runtime_timer = RuntimeTimer ("Creating SFT MapDataset" )
394394
395395 # Load tokenizer & processor & dataset
@@ -701,7 +701,6 @@ def fetch_and_serialize(generator, dtype):
701701 callbacks += [FP8QuantWeightCallback ()]
702702
703703 print ("callbacks:" , callbacks , flush = True )
704- # print("ddd: ", model); exit()
705704
706705 trainer = SFTTrainer (
707706 model = model ,
0 commit comments