Skip to content

Commit cb3abdb

Browse files
committed
remove comment.
1 parent f3fa729 commit cb3abdb

1 file changed

Lines changed: 14 additions & 15 deletions

File tree

paddleformers/cli/train/sft/workflow.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -304,20 +304,6 @@ def run_sft(
304304
LlmMetaConfig.set_llm_config(model_config, training_args)
305305
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm
306306

307-
# autoregressive mtp training
308-
if model_config.mtp_num_layers > 1:
309-
tmp = model_config.mtp_num_layers
310-
model_config.mtp_num_layers = model_config.num_nextn_predict_layers
311-
model_config.num_nextn_predict_layers = tmp
312-
313-
tmp = training_args.mtp_num_layers
314-
training_args.mtp_num_layers = training_args.num_nextn_predict_layers
315-
training_args.num_nextn_predict_layers = tmp
316-
317-
logger.info(
318-
f"MTP args changing for autoregressive mtp training, mtp_num_layers: {model_config.mtp_num_layers}, num_nextn_predict_layers: {model_config.num_nextn_predict_layers}!!"
319-
)
320-
321307
# Config for model using dropout, such as GPT.
322308
if hasattr(model_config, "hidden_dropout_prob"):
323309
model_config.hidden_dropout_prob = finetuning_args.hidden_dropout_prob
@@ -407,6 +393,20 @@ def neft_post_hook(module, input, output):
407393
else:
408394
raise NotImplementedError("Only support neftune for model with get_input_embeddings")
409395

396+
# autoregressive mtp training
397+
if model_config.mtp_num_layers > 1:
398+
tmp = model_config.mtp_num_layers
399+
model_config.mtp_num_layers = model_config.num_nextn_predict_layers
400+
model_config.num_nextn_predict_layers = tmp
401+
402+
tmp = training_args.mtp_num_layers
403+
training_args.mtp_num_layers = training_args.num_nextn_predict_layers
404+
training_args.num_nextn_predict_layers = tmp
405+
406+
logger.info(
407+
f"MTP args changing for autoregressive mtp training, mtp_num_layers: {model_config.mtp_num_layers}, num_nextn_predict_layers: {model_config.num_nextn_predict_layers}!!"
408+
)
409+
410410
runtime_timer = RuntimeTimer("Creating SFT MapDataset")
411411

412412
# Load tokenizer & processor & dataset
@@ -723,7 +723,6 @@ def fetch_and_serialize(generator, dtype):
723723
callbacks += [FP8QuantWeightCallback()]
724724

725725
print("callbacks:", callbacks, flush=True)
726-
# print("ddd: ", model); exit()
727726

728727
trainer = SFTTrainer(
729728
model=model,

0 commit comments

Comments
 (0)