Commit 54fa343

feat(finetune): add warnings for descriptor config mismatches without --use-pretrain-script

Authored by Copilot and committed by njzjz
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Parent: 1d6643b

2 files changed: 174 additions & 0 deletions

File tree

deepmd/pd/train/training.py
deepmd/pt/train/training.py

deepmd/pd/train/training.py

Lines changed: 87 additions & 0 deletions
@@ -82,6 +82,60 @@
 log = logging.getLogger(__name__)
 
 
+def _warn_configuration_mismatch_during_finetune(
+    input_descriptor: dict,
+    pretrained_descriptor: dict,
+    model_branch: str = "Default",
+) -> None:
+    """
+    Warn about configuration mismatches between the input descriptor and the
+    pretrained model when fine-tuning without the --use-pretrain-script option.
+
+    This function warns when the configurations differ and state_dict
+    initialization will only pick the relevant keys from the pretrained model
+    (e.g., the first 6 layers of a 16-layer model).
+
+    Parameters
+    ----------
+    input_descriptor : dict
+        Descriptor configuration from input.json
+    pretrained_descriptor : dict
+        Descriptor configuration from the pretrained model
+    model_branch : str
+        Model branch name for logging context
+    """
+    if input_descriptor == pretrained_descriptor:
+        return
+
+    # Collect differences
+    differences = []
+
+    # Check for keys whose values differ
+    for key in input_descriptor:
+        if key in pretrained_descriptor:
+            if input_descriptor[key] != pretrained_descriptor[key]:
+                differences.append(
+                    f"  {key}: {input_descriptor[key]} (input) vs {pretrained_descriptor[key]} (pretrained)"
+                )
+        else:
+            differences.append(f"  {key}: {input_descriptor[key]} (input only)")
+
+    # Check for keys present only in the pretrained model
+    for key in pretrained_descriptor:
+        if key not in input_descriptor:
+            differences.append(
+                f"  {key}: {pretrained_descriptor[key]} (pretrained only)"
+            )
+
+    if differences:
+        log.warning(
+            f"Descriptor configuration mismatch detected between input.json and pretrained model "
+            f"(branch '{model_branch}'). State dict initialization will only use compatible parameters "
+            f"from the pretrained model. Mismatched configuration:\n"
+            + "\n".join(differences)
+        )
+
+
 class Trainer:
     def __init__(
         self,
@@ -117,6 +171,8 @@ def __init__(
         training_params = config["training"]
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
+        # Store model params for finetune warning comparisons
+        self.model_params = model_params
         self.finetune_update_stat = False
         self.model_keys = (
             list(model_params["model_dict"]) if self.multi_task else ["Default"]
@@ -512,6 +568,37 @@ def collect_single_finetune_params(
                 )
 
             # collect model params from the pretrained model
+            # First check for configuration mismatches and warn if needed
+            pretrained_model_params = state_dict["_extra_state"]["model_params"]
+            for model_key in self.model_keys:
+                finetune_rule_single = self.finetune_links[model_key]
+                _model_key_from = finetune_rule_single.get_model_branch()
+
+                # Get the current model's descriptor config
+                if self.multi_task:
+                    current_descriptor = self.model_params["model_dict"][
+                        model_key
+                    ].get("descriptor", {})
+                else:
+                    current_descriptor = self.model_params.get("descriptor", {})
+
+                # Get the pretrained model's descriptor config
+                if "model_dict" in pretrained_model_params:
+                    pretrained_descriptor = pretrained_model_params[
+                        "model_dict"
+                    ][_model_key_from].get("descriptor", {})
+                else:
+                    pretrained_descriptor = pretrained_model_params.get(
+                        "descriptor", {}
+                    )
+
+                # Warn about configuration mismatches
+                _warn_configuration_mismatch_during_finetune(
+                    current_descriptor,
+                    pretrained_descriptor,
+                    _model_key_from,
+                )
+
             for model_key in self.model_keys:
                 finetune_rule_single = self.finetune_links[model_key]
                 collect_single_finetune_params(
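
For illustration, the difference-collection logic added above can be exercised on its own. The following is a minimal, self-contained sketch; the descriptor keys and values (nlayers, attn_dotr, etc.) are hypothetical and not taken from this commit.

# Condensed re-statement of the comparison logic in
# _warn_configuration_mismatch_during_finetune, runnable outside the Trainer.
def collect_differences(input_descriptor: dict, pretrained_descriptor: dict) -> list:
    differences = []
    # Keys whose values differ, or that exist only in input.json
    for key in input_descriptor:
        if key in pretrained_descriptor:
            if input_descriptor[key] != pretrained_descriptor[key]:
                differences.append(
                    f"  {key}: {input_descriptor[key]} (input) vs "
                    f"{pretrained_descriptor[key]} (pretrained)"
                )
        else:
            differences.append(f"  {key}: {input_descriptor[key]} (input only)")
    # Keys that exist only in the pretrained model
    for key in pretrained_descriptor:
        if key not in input_descriptor:
            differences.append(
                f"  {key}: {pretrained_descriptor[key]} (pretrained only)"
            )
    return differences

# Hypothetical mismatch: input.json asks for 6 layers, the checkpoint has 16.
input_desc = {"type": "se_atten", "nlayers": 6, "rcut": 6.0}
pretrained_desc = {"type": "se_atten", "nlayers": 16, "rcut": 6.0, "attn_dotr": True}
print("\n".join(collect_differences(input_desc, pretrained_desc)))
# Output:
#   nlayers: 6 (input) vs 16 (pretrained)
#   attn_dotr: True (pretrained only)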

deepmd/pt/train/training.py

Lines changed: 87 additions & 0 deletions
@@ -88,6 +88,60 @@
 log = logging.getLogger(__name__)
 
 
+def _warn_configuration_mismatch_during_finetune(
+    input_descriptor: dict,
+    pretrained_descriptor: dict,
+    model_branch: str = "Default",
+) -> None:
+    """
+    Warn about configuration mismatches between the input descriptor and the
+    pretrained model when fine-tuning without the --use-pretrain-script option.
+
+    This function warns when the configurations differ and state_dict
+    initialization will only pick the relevant keys from the pretrained model
+    (e.g., the first 6 layers of a 16-layer model).
+
+    Parameters
+    ----------
+    input_descriptor : dict
+        Descriptor configuration from input.json
+    pretrained_descriptor : dict
+        Descriptor configuration from the pretrained model
+    model_branch : str
+        Model branch name for logging context
+    """
+    if input_descriptor == pretrained_descriptor:
+        return
+
+    # Collect differences
+    differences = []
+
+    # Check for keys whose values differ
+    for key in input_descriptor:
+        if key in pretrained_descriptor:
+            if input_descriptor[key] != pretrained_descriptor[key]:
+                differences.append(
+                    f"  {key}: {input_descriptor[key]} (input) vs {pretrained_descriptor[key]} (pretrained)"
+                )
+        else:
+            differences.append(f"  {key}: {input_descriptor[key]} (input only)")
+
+    # Check for keys present only in the pretrained model
+    for key in pretrained_descriptor:
+        if key not in input_descriptor:
+            differences.append(
+                f"  {key}: {pretrained_descriptor[key]} (pretrained only)"
+            )
+
+    if differences:
+        log.warning(
+            f"Descriptor configuration mismatch detected between input.json and pretrained model "
+            f"(branch '{model_branch}'). State dict initialization will only use compatible parameters "
+            f"from the pretrained model. Mismatched configuration:\n"
+            + "\n".join(differences)
+        )
+
+
 class Trainer:
     def __init__(
         self,
@@ -122,6 +176,8 @@ def __init__(
         training_params = config["training"]
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
+        # Store model params for finetune warning comparisons
+        self.model_params = model_params
         self.finetune_update_stat = False
         self.model_keys = (
             list(model_params["model_dict"]) if self.multi_task else ["Default"]
@@ -541,6 +597,37 @@ def collect_single_finetune_params(
                 )
 
             # collect model params from the pretrained model
+            # First check for configuration mismatches and warn if needed
+            pretrained_model_params = state_dict["_extra_state"]["model_params"]
+            for model_key in self.model_keys:
+                finetune_rule_single = self.finetune_links[model_key]
+                _model_key_from = finetune_rule_single.get_model_branch()
+
+                # Get the current model's descriptor config
+                if self.multi_task:
+                    current_descriptor = self.model_params["model_dict"][
+                        model_key
+                    ].get("descriptor", {})
+                else:
+                    current_descriptor = self.model_params.get("descriptor", {})
+
+                # Get the pretrained model's descriptor config
+                if "model_dict" in pretrained_model_params:
+                    pretrained_descriptor = pretrained_model_params[
+                        "model_dict"
+                    ][_model_key_from].get("descriptor", {})
+                else:
+                    pretrained_descriptor = pretrained_model_params.get(
+                        "descriptor", {}
+                    )
+
+                # Warn about configuration mismatches
+                _warn_configuration_mismatch_during_finetune(
+                    current_descriptor,
+                    pretrained_descriptor,
+                    _model_key_from,
+                )
+
             for model_key in self.model_keys:
                 finetune_rule_single = self.finetune_links[model_key]
                 collect_single_finetune_params(
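
The warning path also has to resolve which branch of a multi-task checkpoint to compare against. A small sketch of that lookup, assuming hypothetical model_params dicts and branch names (the "water"/"metal" branches below are illustrative):

# A multi-task checkpoint stores per-branch configs under "model_dict";
# a single-task checkpoint stores the descriptor at the top level.
def lookup_pretrained_descriptor(
    pretrained_model_params: dict, model_key_from: str
) -> dict:
    # Mirrors the branching in the commit: prefer the per-branch config.
    if "model_dict" in pretrained_model_params:
        return pretrained_model_params["model_dict"][model_key_from].get(
            "descriptor", {}
        )
    return pretrained_model_params.get("descriptor", {})

multi_task_params = {
    "model_dict": {
        "water": {"descriptor": {"type": "se_e2_a", "rcut": 6.0}},
        "metal": {"descriptor": {"type": "se_e2_a", "rcut": 9.0}},
    }
}
single_task_params = {"descriptor": {"type": "se_e2_a", "rcut": 6.0}}

print(lookup_pretrained_descriptor(multi_task_params, "metal"))
# {'type': 'se_e2_a', 'rcut': 9.0}
print(lookup_pretrained_descriptor(single_task_params, "Default"))
# {'type': 'se_e2_a', 'rcut': 6.0}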
