Skip to content

Commit 3781883

Browse files
committed
fix ai comment
1 parent 6761862 commit 3781883

5 files changed

Lines changed: 107 additions & 27 deletions

File tree

deepmd/pt/train/validation.py

Lines changed: 78 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
)
77

88
import logging
9+
import traceback
910
from dataclasses import (
1011
dataclass,
1112
)
@@ -110,6 +111,12 @@ def resolve_full_validation_start_step(
110111
def parse_validation_metric(metric: str) -> tuple[str, str]:
    """Parse the configured full validation metric.

    The raw ``metric`` is first passed through
    ``normalize_full_validation_metric``; the normalized name is then looked
    up in ``METRIC_KEY_MAP``.

    Returns
    -------
    tuple[str, str]
        The normalized metric name and the value ``METRIC_KEY_MAP`` maps it to.

    Raises
    ------
    ValueError
        If the normalized metric is not a key of ``METRIC_KEY_MAP``.
    """
    metric_name = normalize_full_validation_metric(metric)
    if metric_name in METRIC_KEY_MAP:
        return metric_name, METRIC_KEY_MAP[metric_name]

    # Report the supported names in upper case, but echo the user's raw input.
    choices = ", ".join(name.upper() for name in METRIC_KEY_MAP)
    raise ValueError(
        "validating.validation_metric must be one of "
        f"{choices}, got {metric!r}."
    )
114121

115122

@@ -255,7 +262,7 @@ def __init__(
255262
self.enabled = (
256263
self.full_validation
257264
and self.start_step is not None
258-
and self.start_step < num_steps
265+
and self.start_step <= num_steps
259266
)
260267
self.step_column_width = max(len("step"), len(str(num_steps)))
261268
self._write_mode = "a" if restart_training else "w"
@@ -308,21 +315,60 @@ def run(
308315
dist.barrier()
309316

310317
result: FullValidationResult | None = None
318+
caught_exception: Exception | None = None
319+
error_message = None
311320
save_path = [None]
312321
if self.rank == 0:
313-
result = self._evaluate(display_step)
314-
save_path[0] = result.saved_best_path
322+
try:
323+
result = self._evaluate(display_step)
324+
save_path[0] = result.saved_best_path
325+
except Exception as exc:
326+
caught_exception = exc
327+
error_message = (
328+
"Full validation failed on rank 0 during evaluation:\n"
329+
f"{traceback.format_exc()}"
330+
)
331+
332+
self._raise_if_distributed_error(error_message, caught_exception)
315333

316334
if self.is_distributed:
317335
dist.broadcast_object_list(save_path, src=0)
318336

319337
if save_path[0] is not None:
320-
save_checkpoint(Path(save_path[0]), lr=lr, step=step_id)
321-
if self.rank == 0:
322-
self._prune_best_checkpoints(keep_names={Path(save_path[0]).name})
338+
try:
339+
if not self.is_distributed or self.zero_stage == 0:
340+
if self.rank == 0:
341+
save_checkpoint(Path(save_path[0]), lr=lr, step=step_id)
342+
else:
343+
save_checkpoint(Path(save_path[0]), lr=lr, step=step_id)
344+
if self.rank == 0:
345+
self._prune_best_checkpoints(keep_names={Path(save_path[0]).name})
346+
except Exception as exc:
347+
caught_exception = exc
348+
error_message = (
349+
"Full validation failed while saving the best checkpoint:\n"
350+
f"{traceback.format_exc()}"
351+
)
352+
else:
353+
error_message = None
354+
caught_exception = None
355+
356+
self._raise_if_distributed_error(error_message, caught_exception)
323357

324358
if self.rank == 0:
325-
self._log_result(result)
359+
try:
360+
self._log_result(result)
361+
except Exception as exc:
362+
caught_exception = exc
363+
error_message = (
364+
"Full validation failed while writing logs:\n"
365+
f"{traceback.format_exc()}"
366+
)
367+
else:
368+
error_message = None
369+
caught_exception = None
370+
371+
self._raise_if_distributed_error(error_message, caught_exception)
326372

327373
if self.is_distributed:
328374
dist.barrier()
@@ -367,8 +413,12 @@ def evaluate_all_systems(self) -> dict[str, float]:
367413

368414
system_metrics = []
369415
for dataset in self.validation_data.systems:
370-
assert isinstance(dataset, DeepmdDataSetForLoader)
371-
system_metrics.append(self._evaluate_system(dataset._data_system))
416+
if not isinstance(dataset, DeepmdDataSetForLoader):
417+
raise TypeError(
418+
"Full validation expects each dataset in validation_data.systems "
419+
f"to be DeepmdDataSetForLoader, got {type(dataset)!r}."
420+
)
421+
system_metrics.append(self._evaluate_system(dataset.data_system))
372422

373423
aggregated = weighted_average([metric for metric in system_metrics if metric])
374424
return {
@@ -555,6 +605,25 @@ def _initialize_best_checkpoints(self, restart_training: bool) -> None:
555605
else:
556606
self._prune_best_checkpoints()
557607

608+
def _raise_if_distributed_error(
609+
self,
610+
local_error_message: str | None,
611+
local_exception: Exception | None = None,
612+
) -> None:
613+
"""Propagate a local error to all ranks and raise consistently."""
614+
error_message = local_error_message
615+
if self.is_distributed:
616+
gathered_errors = [None] * dist.get_world_size()
617+
dist.all_gather_object(gathered_errors, local_error_message)
618+
error_message = next(
619+
(message for message in gathered_errors if message is not None), None
620+
)
621+
if error_message is None:
622+
return
623+
if local_exception is not None:
624+
raise RuntimeError(error_message) from local_exception
625+
raise RuntimeError(error_message)
626+
558627
def _log_result(self, result: FullValidationResult | None) -> None:
559628
"""Log and persist full validation results on rank 0."""
560629
assert result is not None

deepmd/pt/utils/dataset.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,11 @@ def __init__(
4949
def __len__(self) -> int:
5050
return self._data_system.nframes
5151

52+
@property
def data_system(self) -> DeepmdData:
    """Return the data system stored in ``self._data_system``.

    Read-only accessor so callers do not need to reach into the private
    attribute directly.
    """
    return self._data_system
56+
5257
def __getitem__(self, index: int) -> dict[str, Any]:
5358
"""Get a frame from the selected system."""
5459
b_data = self._data_system.get_item_torch(index, max(1, NUM_WORKERS))

deepmd/utils/argcheck.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4284,10 +4284,6 @@ def validate_full_validation_config(
42844284
if float(validating.get("full_val_start", 0.0)) == 1.0:
42854285
return
42864286

4287-
if multi_task:
4288-
# Unsupported multi-task mode is rejected during trainer initialization.
4289-
return
4290-
42914287
metric = validating["validation_metric"]
42924288
if not is_valid_full_validation_metric(metric):
42934289
valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS)
@@ -4296,9 +4292,19 @@ def validate_full_validation_config(
42964292
f"{valid_metrics}, got {metric!r}."
42974293
)
42984294

4295+
if multi_task:
4296+
raise ValueError(
4297+
"validating.full_validation only supports single-task energy "
4298+
"training; multi-task training is not supported."
4299+
)
4300+
42994301
loss_params = data.get("loss", {})
4300-
if loss_params.get("type", "ener") != "ener":
4301-
return
4302+
loss_type = loss_params.get("type", "ener")
4303+
if loss_type != "ener":
4304+
raise ValueError(
4305+
"validating.full_validation only supports single-task energy "
4306+
f"training with loss.type='ener'; got loss.type={loss_type!r}."
4307+
)
43024308

43034309
if not data.get("training", {}).get("validation_data"):
43044310
raise ValueError(

source/tests/pt/test_training.py

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import shutil
5+
import tempfile
56
import unittest
67
from copy import (
78
deepcopy,
@@ -761,6 +762,9 @@ def test_fitting_stat_consistency(self) -> None:
761762

762763
class TestFullValidation(unittest.TestCase):
763764
def setUp(self) -> None:
765+
self._cwd = os.getcwd()
766+
self._tmpdir = tempfile.TemporaryDirectory()
767+
os.chdir(self._tmpdir.name)
764768
input_json = str(Path(__file__).parent / "water/se_atten.json")
765769
with open(input_json) as f:
766770
self.config = json.load(f)
@@ -782,13 +786,8 @@ def setUp(self) -> None:
782786
}
783787

784788
def tearDown(self) -> None:
785-
for f in os.listdir("."):
786-
if (f.startswith("model") or f.startswith("best")) and f.endswith(".pt"):
787-
os.remove(f)
788-
if f in ["lcurve.out", "val.log", "checkpoint"]:
789-
os.remove(f)
790-
if f.startswith("stat_files"):
791-
shutil.rmtree(f)
789+
os.chdir(self._cwd)
790+
self._tmpdir.cleanup()
792791

793792
@patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
794793
def test_full_validation_rotates_best_checkpoint(self, mocked_eval) -> None:
@@ -843,12 +842,10 @@ def test_full_validation_rejects_multitask(self) -> None:
843842
"full_val_file": "val.log",
844843
"full_val_start": 0.0,
845844
}
846-
config["model"], shared_links = preprocess_shared_params(config["model"])
845+
config["model"], _ = preprocess_shared_params(config["model"])
847846
config = update_deepmd_input(config, warning=False)
848-
config = normalize(config, multi_task=True)
849-
850847
with self.assertRaisesRegex(ValueError, "multi-task"):
851-
get_trainer(config, shared_links=shared_links)
848+
normalize(config, multi_task=True)
852849

853850

854851
if __name__ == "__main__":

source/tests/pt/test_validation.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,9 @@
77
)
88

99
import torch
10+
from dargs.dargs import (
11+
ArgumentValueError,
12+
)
1013

1114
from deepmd.pt.train.validation import (
1215
FullValidator,
@@ -135,5 +138,5 @@ def test_normalize_rejects_zero_prefactor_metric(self) -> None:
135138
def test_normalize_rejects_invalid_metric(self) -> None:
136139
config = _make_single_task_config()
137140
config["validating"]["validation_metric"] = "X:MAE"
138-
with self.assertRaisesRegex(Exception, "validation_metric"):
141+
with self.assertRaisesRegex(ArgumentValueError, "validation_metric"):
139142
normalize(config)

0 commit comments

Comments
 (0)