Skip to content

Commit c3936e5

Browse files
committed
remove comment.
1 parent b9eea88 commit c3936e5

1 file changed

Lines changed: 14 additions & 15 deletions

File tree

paddleformers/cli/train/sft/workflow.py

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -290,20 +290,6 @@ def run_sft(
290290
LlmMetaConfig.set_llm_config(model_config, training_args)
291291
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm
292292

293-
# autoregressive mtp training
294-
if model_config.mtp_num_layers > 1:
295-
tmp = model_config.mtp_num_layers
296-
model_config.mtp_num_layers = model_config.num_nextn_predict_layers
297-
model_config.num_nextn_predict_layers = tmp
298-
299-
tmp = training_args.mtp_num_layers
300-
training_args.mtp_num_layers = training_args.num_nextn_predict_layers
301-
training_args.num_nextn_predict_layers = tmp
302-
303-
logger.info(
304-
f"MTP args changing for autoregressive mtp training, mtp_num_layers: {model_config.mtp_num_layers}, num_nextn_predict_layers: {model_config.num_nextn_predict_layers}!!"
305-
)
306-
307293
# Config for model using dropout, such as GPT.
308294
if hasattr(model_config, "hidden_dropout_prob"):
309295
model_config.hidden_dropout_prob = finetuning_args.hidden_dropout_prob
@@ -390,6 +376,20 @@ def neft_post_hook(module, input, output):
390376
else:
391377
raise NotImplementedError("Only support neftune for model with get_input_embeddings")
392378

379+
# autoregressive mtp training
380+
if model_config.mtp_num_layers > 1:
381+
tmp = model_config.mtp_num_layers
382+
model_config.mtp_num_layers = model_config.num_nextn_predict_layers
383+
model_config.num_nextn_predict_layers = tmp
384+
385+
tmp = training_args.mtp_num_layers
386+
training_args.mtp_num_layers = training_args.num_nextn_predict_layers
387+
training_args.num_nextn_predict_layers = tmp
388+
389+
logger.info(
390+
f"MTP args changing for autoregressive mtp training, mtp_num_layers: {model_config.mtp_num_layers}, num_nextn_predict_layers: {model_config.num_nextn_predict_layers}!!"
391+
)
392+
393393
runtime_timer = RuntimeTimer("Creating SFT MapDataset")
394394

395395
# Load tokenizer & processor & dataset
@@ -701,7 +701,6 @@ def fetch_and_serialize(generator, dtype):
701701
callbacks += [FP8QuantWeightCallback()]
702702

703703
print("callbacks:", callbacks, flush=True)
704-
# print("ddd: ", model); exit()
705704

706705
trainer = SFTTrainer(
707706
model=model,

0 commit comments

Comments
 (0)