Skip to content

Commit 0f50be0

Browse files
committed
feat: add full validation
1 parent 034e613 commit 0f50be0

5 files changed

Lines changed: 1090 additions & 0 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,10 @@
5454
KFOptimizerWrapper,
5555
LKFOptimizer,
5656
)
57+
from deepmd.pt.train.validation import (
58+
FullValidator,
59+
resolve_full_validation_start_step,
60+
)
5761
from deepmd.pt.train.wrapper import (
5862
ModelWrapper,
5963
)
@@ -857,6 +861,57 @@ def single_model_finetune(
857861
self.enable_profiler = training_params.get("enable_profiler", False)
858862
self.profiling = training_params.get("profiling", False)
859863
self.profiling_file = training_params.get("profiling_file", "timeline.json")
864+
self.full_validator = None
865+
866+
validating_params = config.get("validating") or {}
867+
validation_start_step = resolve_full_validation_start_step(
868+
validating_params.get("full_val_start", 0.0),
869+
self.num_steps,
870+
)
871+
full_validation_requested = (
872+
bool(validating_params.get("full_validation", False))
873+
and validation_start_step is not None
874+
and validation_start_step < self.num_steps
875+
)
876+
if full_validation_requested:
877+
if self.multi_task:
878+
raise ValueError(
879+
"validating.full_validation only supports single-task energy "
880+
"training; multi-task training is not supported."
881+
)
882+
has_spin = getattr(self.model, "has_spin", False)
883+
if callable(has_spin):
884+
has_spin = has_spin()
885+
if has_spin or isinstance(self.loss, EnergySpinLoss):
886+
raise ValueError(
887+
"validating.full_validation only supports single-task energy "
888+
"training; spin-energy training is not supported."
889+
)
890+
if not isinstance(self.loss, EnergyStdLoss):
891+
raise ValueError(
892+
"validating.full_validation only supports single-task energy "
893+
"training."
894+
)
895+
if validation_data is None:
896+
raise ValueError(
897+
"validating.full_validation requires `training.validation_data` "
898+
"to be configured."
899+
)
900+
if self.zero_stage >= 2:
901+
raise ValueError(
902+
"validating.full_validation only supports single-task energy "
903+
"training with training.zero_stage < 2."
904+
)
905+
self.full_validator = FullValidator(
906+
validating_params=validating_params,
907+
validation_data=validation_data,
908+
model=self.model,
909+
train_infos=self._get_inner_module().train_infos,
910+
num_steps=self.num_steps,
911+
rank=self.rank,
912+
zero_stage=self.zero_stage,
913+
restart_training=self.restart_training,
914+
)
860915

861916
# Log model parameter count
862917
if self.rank == 0:
@@ -1363,6 +1418,14 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
13631418
fout, display_step_id, cur_lr, train_results, valid_results
13641419
)
13651420

1421+
if self.full_validator is not None:
1422+
self.full_validator.run(
1423+
step_id=_step_id,
1424+
display_step=display_step_id,
1425+
lr=cur_lr,
1426+
save_checkpoint=self.save_model,
1427+
)
1428+
13661429
if (
13671430
(
13681431
(display_step_id) % self.save_freq == 0

0 commit comments

Comments (0)