
Commit 6158d9c

Han Wang committed
fix: address CodeQL findings in PR #5397
- Replace assert with if/raise ValueError for user-facing config validation (data_stat_protect, finetune branch/head checks).
- Wrap train() in try/finally for destroy_process_group cleanup.
- Add parents=True, exist_ok=True to stat_file mkdir.
- Add strict=True to zip() calls.
- Fix minor test issues.
1 parent 80c714c

4 files changed: 38 additions & 30 deletions


deepmd/pt_expt/entrypoints/main.py (15 additions & 14 deletions)

@@ -91,7 +91,7 @@ def get_trainer(
             with h5py.File(stat_file_path, "w"):
                 pass
         else:
-            Path(stat_file_path).mkdir()
+            Path(stat_file_path).mkdir(parents=True, exist_ok=True)
         stat_file_path = DPPath(stat_file_path, "a")
     else:
         # Multi-task: build per-task data systems
@@ -143,7 +143,7 @@ def get_trainer(
                 with h5py.File(_sf, "w"):
                     pass
             else:
-                Path(_sf).mkdir(parents=True)
+                Path(_sf).mkdir(parents=True, exist_ok=True)
             stat_file_path[model_key] = DPPath(_sf, "a")
         else:
             stat_file_path[model_key] = None
@@ -290,18 +290,19 @@ def train(
     if os.environ.get("LOCAL_RANK") is not None:
         dist.init_process_group(backend="cuda:nccl,cpu:gloo")
 
-    trainer = get_trainer(
-        config,
-        init_model,
-        restart,
-        finetune_model=finetune,
-        finetune_links=finetune_links,
-        shared_links=shared_links,
-    )
-    trainer.run()
-
-    if dist.is_available() and dist.is_initialized():
-        dist.destroy_process_group()
+    try:
+        trainer = get_trainer(
+            config,
+            init_model,
+            restart,
+            finetune_model=finetune,
+            finetune_links=finetune_links,
+            shared_links=shared_links,
+        )
+        trainer.run()
+    finally:
+        if dist.is_available() and dist.is_initialized():
+            dist.destroy_process_group()
 
 
 def freeze(
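
The try/finally wrapper above guarantees that the process group is torn down even if get_trainer() or trainer.run() raises, so a failed run no longer leaks NCCL/Gloo resources. Below is a minimal standalone sketch of the same pattern; do_work() is a hypothetical stand-in for the trainer, while the torch.distributed calls match the ones in the diff:

import os

import torch.distributed as dist


def do_work() -> None:
    # Hypothetical placeholder for get_trainer(...) + trainer.run().
    pass


def main() -> None:
    # Initialize only when launched by a distributed launcher
    # (torchrun sets LOCAL_RANK), mirroring the check in train().
    if os.environ.get("LOCAL_RANK") is not None:
        dist.init_process_group(backend="gloo")
    try:
        do_work()
    finally:
        # Runs on both success and exception, so the process group
        # is always released before the interpreter exits.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()

The mkdir(parents=True, exist_ok=True) changes in the same file serve a similar robustness goal: creating the stat directory becomes idempotent, so restarting against an existing path no longer raises FileExistsError.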

deepmd/pt_expt/train/training.py (10 additions & 5 deletions)

@@ -611,13 +611,16 @@ def _make_sample(
                     for ii in model_params["model_dict"]
                 ]
             )
-            assert np.allclose(_data_stat_protect, _data_stat_protect[0]), (
-                "Model key 'data_stat_protect' must be the same in each branch when multitask!"
-            )
+            if not np.allclose(_data_stat_protect, _data_stat_protect[0]):
+                raise ValueError(
+                    "Model key 'data_stat_protect' must be the same in each branch when multitask!"
+                )
             self.wrapper.share_params(
                 shared_links,
                 resume=(resuming and not self._finetune_update_stat) or self.rank != 0,
-                model_key_prob_map=dict(zip(self.model_keys, self.model_prob)),
+                model_key_prob_map=dict(
+                    zip(self.model_keys, self.model_prob, strict=True)
+                ),
                 data_stat_protect=_data_stat_protect[0],
             )
 
@@ -825,7 +828,9 @@ def _make_sample(
             self._unwrapped.share_params(
                 shared_links,
                 resume=True,
-                model_key_prob_map=dict(zip(self.model_keys, self.model_prob)),
+                model_key_prob_map=dict(
+                    zip(self.model_keys, self.model_prob, strict=True)
+                ),
             )
 
             if optimizer_state_dict is not None:
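
Passing strict=True (Python 3.10+) makes zip() raise instead of silently truncating when self.model_keys and self.model_prob ever drift out of length. A quick illustration with stand-in lists:

keys = ["model_1", "model_2"]
probs = [0.5]  # one entry short

# Default zip() silently drops the unmatched key:
print(dict(zip(keys, probs)))  # {'model_1': 0.5}; "model_2" vanishes

# strict=True surfaces the mismatch immediately:
try:
    dict(zip(keys, probs, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1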

deepmd/pt_expt/utils/finetune.py (10 additions & 8 deletions)

@@ -105,10 +105,11 @@ def get_finetune_rules(
         finetune_links["Default"] = finetune_rule
     else:
         # Multi-task target — mirrors PT's logic
-        assert model_branch == "", (
-            "Multi-task fine-tuning does not support command-line branches chosen!"
-            "Please define the 'finetune_head' in each model params!"
-        )
+        if model_branch != "":
+            raise ValueError(
+                "Multi-task fine-tuning does not support command-line branches chosen! "
+                "Please define the 'finetune_head' in each model params!"
+            )
         if not finetune_from_multi_task:
             pretrained_keys = ["Default"]
         else:
@@ -120,10 +121,11 @@ def get_finetune_rules(
                 and model_config["model_dict"][model_key]["finetune_head"] != "RANDOM"
             ):
                 pretrained_key = model_config["model_dict"][model_key]["finetune_head"]
-                assert pretrained_key in pretrained_keys, (
-                    f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model!"
-                    f"Available heads are: {list(pretrained_keys)}"
-                )
+                if pretrained_key not in pretrained_keys:
+                    raise ValueError(
+                        f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model! "
+                        f"Available heads are: {list(pretrained_keys)}"
+                    )
                 model_branch_from = pretrained_key
             elif (
                 "finetune_head" not in model_config["model_dict"][model_key]

source/tests/pt_expt/test_multitask.py (3 additions & 3 deletions)

@@ -936,7 +936,7 @@ def test_multitask_finetune_no_change_model_params(self) -> None:
         # Also set trainable=False to verify it's preserved
         ft_config_true["model"]["shared_dict"]["my_descriptor"]["trainable"] = False
         ft_config_true["model"], _ = preprocess_shared_params(ft_config_true["model"])
-        model_config_true, finetune_links_true = get_finetune_rules(
+        model_config_true, _finetune_links_true = get_finetune_rules(
             ckpt_path, deepcopy(ft_config_true["model"]), change_model_params=True
         )
 
@@ -1164,7 +1164,7 @@ def test_multitask_restart(self) -> None:
     def test_multitask_freeze(self) -> None:
         """Train, then freeze with --head and verify.
 
-        Only runs for se_e2_a descriptor to avoid redundant slow freeze tests.
+        Only runs for dpa3 descriptor to avoid redundant slow freeze tests.
         """
         if self.descriptor.get("type") != "dpa3":
             return
@@ -1948,7 +1948,7 @@ def test_gradient_accumulation(self) -> None:
 
         # Verify descriptor params are aliased (share_params)
         mt_desc_2 = mt_trainer.wrapper.model["model_2"].atomic_model.descriptor
-        for (n1, p1), (n2, p2) in zip(
+        for (n1, p1), (_n2, p2) in zip(
             mt_desc.named_parameters(), mt_desc_2.named_parameters(), strict=True
         ):
             assert p1.data_ptr() == p2.data_ptr(), (
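
The data_ptr() assertion in the last hunk checks that share_params aliased the descriptor parameters, i.e. both models hold the very same tensors rather than equal copies. A small self-contained sketch of that distinction:

import torch
from torch import nn

a = nn.Linear(4, 4)
b = nn.Linear(4, 4)
b.weight = a.weight  # alias: both modules now hold the same Parameter
print(a.weight.data_ptr() == b.weight.data_ptr())  # True: shared storage

c = nn.Linear(4, 4)
c.weight = nn.Parameter(a.weight.detach().clone())  # equal values, separate storage
print(torch.equal(a.weight, c.weight))              # True
print(a.weight.data_ptr() == c.weight.data_ptr())   # False: a copy, not an alias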
