Skip to content

Commit 3104a30

Browse files
committed
soft coding.
1 parent 8bd895d commit 3104a30

3 files changed

Lines changed: 29 additions & 3 deletions

File tree

paddleformers/cli/train/sft/workflow.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def extract_layer_idx(text):
9999
return None
100100

101101
# NOTE: not sure this works for all models
102-
jackpot = set(range(config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers))
102+
jackpot = set(range(config.num_hidden_layers, config.num_hidden_layers + config.mtp_num_layers))
103103
for name, param in model.state_dict().items():
104104
layer_idx = extract_layer_idx(name)
105105
is_mtp = layer_idx in jackpot
@@ -288,6 +288,22 @@ def run_sft(
288288
LlmMetaConfig.set_llm_config(model_config, training_args)
289289
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm
290290

291+
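# Autoregressive MTP training: when model_config.mtp_num_layers > 1, swap mtp_num_layers and num_nextn_predict_layers on both model_config and training_args.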
292+
activate_autoregressive_mtp_training = False
293+
if model_config.mtp_num_layers > 1:
294+
activate_autoregressive_mtp_training = True
295+
tmp = model_config.mtp_num_layers
296+
model_config.mtp_num_layers = model_config.num_nextn_predict_layers
297+
model_config.num_nextn_predict_layers = tmp
298+
299+
tmp = training_args.mtp_num_layers
300+
training_args.mtp_num_layers = training_args.num_nextn_predict_layers
301+
training_args.num_nextn_predict_layers = tmp
302+
303+
logger.info(
304+
f"MTP args changing for autoregressive mtp training, mtp_num_layers: {model_config.mtp_num_layers}, num_nextn_predict_layers: {model_config.num_nextn_predict_layers}!!"
305+
)
306+
291307
# Config for model using dropout, such as GPT.
292308
if hasattr(model_config, "hidden_dropout_prob"):
293309
model_config.hidden_dropout_prob = finetuning_args.hidden_dropout_prob
@@ -699,10 +715,14 @@ def fetch_and_serialize(generator, dtype):
699715
data_args=data_args,
700716
callbacks=callbacks,
701717
)
718+
freeze_param_except_mtp(model, model_config)
719+
720+
if activate_autoregressive_mtp_training:
721+
# activate autoregressive mtp training
722+
freeze_param_except_mtp(model, model_config)
702723
trainable_parameters = [
703724
p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)
704725
]
705-
freeze_param_except_mtp(model, model_config)
706726
trainer.set_optimizer_grouped_parameters(trainable_parameters)
707727

708728
# Train

paddleformers/trainer/training_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,10 @@ class TrainingArguments:
12121212
metadata={"help": "pre allocate memory size GB"},
12131213
)
12141214
num_nextn_predict_layers: int = field(default=0, metadata={"help": "Number of nextn predict layers."})
1215+
mtp_distillation_loss: bool = field(default=False, metadata={"help": "Whether to use distillation MTP loss."})
1216+
mtp_num_layers: int = field(
1217+
default=0, metadata={"help": "Whether to use Autoregressive MTP Training, activate if > 1."}
1218+
)
12151219
profile: bool = field(default=False, metadata={"help": "Enable nsys profiling."})
12161220
profile_step_start: int = field(default=10, metadata={"help": "Step to start nsys profiling."})
12171221
profile_step_end: int = field(default=12, metadata={"help": "Step to end nsys profiling."})

paddleformers/transformers/configuration_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,9 @@ class LlmMetaConfig:
413413
]
414414

415415
mtp_attributes = [
416-
("num_nextn_predict_layers", int, 0, "Number of nextn predict layers."),
416+
# ("num_nextn_predict_layers", int, 0, "Number of nextn predict layers."),
417+
("mtp_distillation_loss", bool, False, "Whether to use distillation MTP loss."),
418+
("mtp_num_layers", int, 0, "Whether to use Autoregressive MTP Training, activate if > 1."),
417419
(
418420
"mtp_loss_scaling_factor",
419421
float,

0 commit comments

Comments
 (0)