Skip to content

Commit e875944

Browse files
committed
feat: add full validation
1 parent efc27cf commit e875944

File tree

13 files changed

+1976
-225
lines changed

13 files changed

+1976
-225
lines changed

deepmd/entrypoints/test.py

Lines changed: 338 additions & 213 deletions
Large diffs are not rendered by default.

deepmd/jax/utils/auto_batch_size.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@ def __init__(
2424
self,
2525
initial_batch_size: int = 1024,
2626
factor: float = 2.0,
27+
*,
28+
silent: bool = False,
2729
) -> None:
2830
super().__init__(
2931
initial_batch_size=initial_batch_size,
3032
factor=factor,
33+
silent=silent,
3134
)
3235

3336
def is_gpu_available(self) -> bool:

deepmd/pd/utils/auto_batch_size.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ def __init__(
2222
self,
2323
initial_batch_size: int = 1024,
2424
factor: float = 2.0,
25+
*,
26+
silent: bool = False,
2527
) -> None:
2628
super().__init__(
2729
initial_batch_size=initial_batch_size,
2830
factor=factor,
31+
silent=silent,
2932
)
3033

3134
def is_gpu_available(self) -> bool:

deepmd/pt/train/training.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@
5454
KFOptimizerWrapper,
5555
LKFOptimizer,
5656
)
57+
from deepmd.pt.train.validation import (
58+
FullValidator,
59+
resolve_full_validation_start_step,
60+
)
5761
from deepmd.pt.train.wrapper import (
5862
ModelWrapper,
5963
)
@@ -857,11 +861,88 @@ def single_model_finetune(
857861
self.enable_profiler = training_params.get("enable_profiler", False)
858862
self.profiling = training_params.get("profiling", False)
859863
self.profiling_file = training_params.get("profiling_file", "timeline.json")
864+
validating_params = config.get("validating") or {}
865+
self.full_validator = self._create_full_validator(
866+
validating_params=validating_params,
867+
validation_data=validation_data,
868+
)
860869

861870
# Log model parameter count
862871
if self.rank == 0:
863872
self._log_parameter_count()
864873

874+
def _create_full_validator(
875+
self,
876+
*,
877+
validating_params: dict[str, Any],
878+
validation_data: DpLoaderSet | None,
879+
) -> FullValidator | None:
880+
"""Create the runtime full validator when it is active."""
881+
if not self._is_full_validation_requested(validating_params):
882+
return None
883+
self._raise_if_full_validation_unsupported(validation_data)
884+
if validation_data is None:
885+
raise RuntimeError(
886+
"validation_data must be available after full validation checks."
887+
)
888+
return FullValidator(
889+
validating_params=validating_params,
890+
validation_data=validation_data,
891+
model=self.model,
892+
train_infos=self._get_inner_module().train_infos,
893+
num_steps=self.num_steps,
894+
rank=self.rank,
895+
zero_stage=self.zero_stage,
896+
restart_training=self.restart_training,
897+
)
898+
899+
def _is_full_validation_requested(self, validating_params: dict[str, Any]) -> bool:
900+
"""Check whether full validation can trigger during this training run."""
901+
if not validating_params.get("full_validation", False):
902+
return False
903+
start_step = resolve_full_validation_start_step(
904+
validating_params.get("full_val_start", 0.0),
905+
self.num_steps,
906+
)
907+
return start_step is not None and start_step <= self.num_steps
908+
909+
def _raise_if_full_validation_unsupported(
910+
self,
911+
validation_data: DpLoaderSet | None,
912+
) -> None:
913+
"""Validate runtime full validation constraints."""
914+
if self.multi_task:
915+
raise ValueError(
916+
"validating.full_validation only supports single-task energy "
917+
"training; multi-task training is not supported."
918+
)
919+
920+
has_spin = getattr(self.model, "has_spin", False)
921+
if callable(has_spin):
922+
has_spin = has_spin()
923+
if has_spin or isinstance(self.loss, EnergySpinLoss):
924+
raise ValueError(
925+
"validating.full_validation only supports single-task energy "
926+
"training; spin-energy training is not supported."
927+
)
928+
929+
if not isinstance(self.loss, EnergyStdLoss):
930+
raise ValueError(
931+
"validating.full_validation only supports single-task energy training."
932+
)
933+
934+
if validation_data is None:
935+
raise ValueError(
936+
"validating.full_validation requires `training.validation_data` "
937+
"to be configured."
938+
)
939+
940+
if self.zero_stage >= 2:
941+
raise ValueError(
942+
"validating.full_validation only supports single-task energy "
943+
"training with training.zero_stage < 2."
944+
)
945+
865946
@staticmethod
866947
def _count_parameters(model: torch.nn.Module) -> tuple[int, int]:
867948
"""
@@ -1363,6 +1444,14 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
13631444
fout, display_step_id, cur_lr, train_results, valid_results
13641445
)
13651446

1447+
if self.full_validator is not None:
1448+
self.full_validator.run(
1449+
step_id=_step_id,
1450+
display_step=display_step_id,
1451+
lr=cur_lr,
1452+
save_checkpoint=self.save_model,
1453+
)
1454+
13661455
if (
13671456
(
13681457
(display_step_id) % self.save_freq == 0

0 commit comments

Comments
 (0)