Skip to content

Commit f216f89

Browse files
committed
causal mtp [hard coding].
1 parent edfc4ca commit f216f89

3 files changed

Lines changed: 33 additions & 0 deletions

File tree

paddleformers/cli/train/sft/workflow.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import gc
1818
import math
1919
import os
20+
import re
2021
from dataclasses import fields
2122
from functools import partial
2223

@@ -85,6 +86,29 @@
8586
)
8687

8788

89+
def frozen_param_expect_mtp(model, config):
90+
def extract_layer_idx(text):
91+
match = re.search(r"model.layers.(-?\d+\.?\d*)", text)
92+
if match:
93+
num_str = match.group(1)
94+
# 区分整数和小数返回(避免123.0这种冗余浮点数)
95+
if "." in num_str:
96+
return float(num_str)
97+
else:
98+
return int(num_str)
99+
return None
100+
101+
# not sure can work on all model
102+
jackpot = set(range(config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers))
103+
for name, param in model.state_dict().items():
104+
layer_idx = extract_layer_idx(name)
105+
is_mtp = layer_idx in jackpot
106+
if not is_mtp:
107+
param.stop_gradient = True
108+
else:
109+
param.stop_gradient = False
110+
111+
88112
def create_pretrained_dataset(training_args, data_args, model_args):
89113
assert data_args.input_dir is not None and len(data_args.input_dir.split()) > 1
90114

@@ -653,6 +677,8 @@ def neft_post_hook(module, input, output):
653677
callbacks += [FP8QuantWeightCallback()]
654678

655679
print("callbacks:", callbacks, flush=True)
680+
# print("ddd: ", model); exit()
681+
656682
trainer = SFTTrainer(
657683
model=model,
658684
args=training_args,
@@ -665,6 +691,7 @@ def neft_post_hook(module, input, output):
665691
data_args=data_args,
666692
callbacks=callbacks,
667693
)
694+
frozen_param_expect_mtp(model, model_config)
668695
trainable_parameters = [p for p in model.parameters() if not p.stop_gradient]
669696
trainer.set_optimizer_grouped_parameters(trainable_parameters)
670697

paddleformers/trainer/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3147,6 +3147,8 @@ def new_global_norm_func(
31473147
global_norm_var_not_dist,
31483148
*args,
31493149
):
3150+
print("WE DO NOT CAL GRAD NORM FOR NOW")
3151+
return
31503152
if len(args) > 0:
31513153
global_norm_func(global_norm_var_dist, global_norm_var_not_dist, *args)
31523154
global_norm_var_dist_moe, global_norm_var_not_dist_moe = args
@@ -3157,6 +3159,8 @@ def new_global_norm_func(
31573159
+ global_norm_var_not_dist_moe
31583160
)
31593161
else:
3162+
print("global_norm_var_dist: ", global_norm_var_dist)
3163+
print("global_norm_var_not_dist: ", global_norm_var_not_dist)
31603164
global_norm_func(global_norm_var_dist, global_norm_var_not_dist)
31613165
global_norm_var_fp32 = paddle.sqrt(global_norm_var_dist + global_norm_var_not_dist)
31623166
training_logs["global_norm"] = global_norm_var_fp32.item()

paddleformers/transformers/glm4_moe/modeling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,7 @@ def _gen_aoa_config(cls, config: Glm4MoeConfig):
911911
]
912912

913913
num_nextn_predict_layers = config.num_nextn_predict_layers if config.num_nextn_predict_layers else 0
914+
num_nextn_predict_layers = 1
914915

915916
for layer_idx in reversed(range(num_hidden_layers, num_hidden_layers + num_nextn_predict_layers)):
916917
layer_idx_offset = layer_idx + num_head_empty_layers
@@ -1057,6 +1058,7 @@ def _gen_inv_aoa_config(cls, config: Glm4MoeConfig):
10571058
]
10581059

10591060
num_nextn_predict_layers = config.num_nextn_predict_layers if config.num_nextn_predict_layers else 0
1061+
num_nextn_predict_layers = 1
10601062

10611063
for layer_idx in reversed(range(num_hidden_layers, num_hidden_layers + num_nextn_predict_layers)):
10621064
layer_idx_offset = layer_idx + num_head_empty_layers

0 commit comments

Comments
 (0)