Skip to content

Commit ef6e4ff

Browse files
Copilot and njzjz committed
feat(finetune): enhance nlayer warnings to support DPA2/DPA3 repformer.nlayers
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 24b37b0 commit ef6e4ff

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

deepmd/pd/utils/finetune.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def _warn_descriptor_config_differences(
5858
f"with the pretrained model's configuration:\n" + "\n".join(differences)
5959
)
6060

61-
# Special warning for nlayer changes
61+
# Special warning for nlayer changes (check both top-level and nested)
6262
if (
6363
"nlayer" in input_descriptor
6464
and "nlayer" in pretrained_descriptor
@@ -70,6 +70,20 @@ def _warn_descriptor_config_differences(
7070
f"model architecture and performance."
7171
)
7272

73+
# Check for nested nlayers in repformer (DPA2/DPA3 models)
74+
input_repformer = input_descriptor.get("repformer", {})
75+
pretrained_repformer = pretrained_descriptor.get("repformer", {})
76+
if (
77+
"nlayers" in input_repformer
78+
and "nlayers" in pretrained_repformer
79+
and input_repformer["nlayers"] != pretrained_repformer["nlayers"]
80+
):
81+
log.warning(
82+
f"IMPORTANT: repformer.nlayers changed from {input_repformer['nlayers']} to "
83+
f"{pretrained_repformer['nlayers']}. This may significantly affect "
84+
f"model architecture and performance."
85+
)
86+
7387

7488
def get_finetune_rule_single(
7589
_single_param_target,

deepmd/pt/utils/finetune.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ def _warn_descriptor_config_differences(
6161
f"with the pretrained model's configuration:\n" + "\n".join(differences)
6262
)
6363

64-
# Special warning for nlayer changes
64+
# Special warning for nlayer changes (check both top-level and nested)
6565
if (
6666
"nlayer" in input_descriptor
6767
and "nlayer" in pretrained_descriptor
@@ -73,6 +73,20 @@ def _warn_descriptor_config_differences(
7373
f"model architecture and performance."
7474
)
7575

76+
# Check for nested nlayers in repformer (DPA2/DPA3 models)
77+
input_repformer = input_descriptor.get("repformer", {})
78+
pretrained_repformer = pretrained_descriptor.get("repformer", {})
79+
if (
80+
"nlayers" in input_repformer
81+
and "nlayers" in pretrained_repformer
82+
and input_repformer["nlayers"] != pretrained_repformer["nlayers"]
83+
):
84+
log.warning(
85+
f"IMPORTANT: repformer.nlayers changed from {input_repformer['nlayers']} to "
86+
f"{pretrained_repformer['nlayers']}. This may significantly affect "
87+
f"model architecture and performance."
88+
)
89+
7690

7791
def get_finetune_rule_single(
7892
_single_param_target,

0 commit comments

Comments
 (0)