Skip to content

Commit 9300fac

Browse files
committed
Ignore EMA full validation (`ema_full_validation`) when `full_validation` is false
1 parent 96b2817 commit 9300fac

3 files changed

Lines changed: 33 additions & 7 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1027,7 +1027,9 @@ def _create_ema_full_validator(
10271027
validation_data: DpLoaderSet | None,
10281028
) -> FullValidator | None:
10291029
"""Create the runtime EMA full validator when it is active."""
1030-
if not self._is_validation_requested(validating_params, "ema_full_validation"):
1030+
if not self._is_validation_requested(
1031+
validating_params, "full_validation"
1032+
) or not validating_params.get("ema_full_validation", False):
10311033
return None
10321034
self._raise_if_full_validation_unsupported(validation_data)
10331035
if self.model_ema is None:

deepmd/utils/argcheck.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4257,10 +4257,11 @@ def validating_args() -> Argument:
42574257
)
42584258
doc_ema_full_validation = (
42594259
"Whether to additionally run the same full validation flow on the "
4260-
"EMA-smoothed model. This reuses the existing full validation schedule, "
4261-
"metric, start step, and best-checkpoint settings, writes results to an "
4262-
"EMA-specific validation log such as `val_ema.log`, and saves EMA best "
4263-
"checkpoints with a `best_ema.ckpt` prefix. Requires "
4260+
"EMA-smoothed model when `validating.full_validation=true`. This reuses "
4261+
"the existing full validation schedule, metric, start step, and "
4262+
"best-checkpoint settings, writes results to an EMA-specific validation "
4263+
"log such as `val_ema.log`, and saves EMA best checkpoints with a "
4264+
"`best_ema.ckpt` prefix. Requires "
42644265
"`training.enable_ema=true`."
42654266
)
42664267
doc_max_best_ckpt = (
@@ -4374,7 +4375,7 @@ def validate_full_validation_config(
43744375
training_params = data.get("training", {}) or {}
43754376
full_validation_enabled = bool(validating.get("full_validation", False))
43764377
ema_full_validation_enabled = bool(validating.get("ema_full_validation", False))
4377-
if not full_validation_enabled and not ema_full_validation_enabled:
4378+
if not full_validation_enabled:
43784379
return
43794380
if float(validating.get("full_val_start", 0.0)) == 1.0:
43804381
return

source/tests/pt/test_training.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1063,17 +1063,22 @@ def test_restart_restores_ema_state(self) -> None:
10631063
@patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
10641064
def test_ema_full_validation_writes_separate_outputs(self, mocked_eval) -> None:
10651065
mocked_eval.side_effect = [
1066+
{"mae_e_per_atom": 10.0},
10661067
{"mae_e_per_atom": 1.0},
1068+
{"mae_e_per_atom": 10.0},
10671069
{"mae_e_per_atom": 0.5},
1070+
{"mae_e_per_atom": 10.0},
10681071
{"mae_e_per_atom": 0.75},
1072+
{"mae_e_per_atom": 10.0},
10691073
{"mae_e_per_atom": 0.25},
10701074
]
10711075
config = deepcopy(self.config)
1076+
config["validating"]["full_validation"] = True
10721077
config["validating"]["ema_full_validation"] = True
10731078
trainer = get_trainer(config)
10741079
trainer.run()
10751080

1076-
self.assertFalse(Path("val.log").exists())
1081+
self.assertTrue(Path("val.log").exists())
10771082
self.assertTrue(Path("val_ema.log").exists())
10781083
self.assertTrue(Path("best_ema.ckpt-4.t-1.pt").exists())
10791084
self.assertFalse(Path("best_ema.ckpt-1.t-1.pt").exists())
@@ -1085,6 +1090,24 @@ def test_ema_full_validation_writes_separate_outputs(self, mocked_eval) -> None:
10851090
[{"metric": 0.25, "step": 4}],
10861091
)
10871092

1093+
@TRAINING_TEST_TIMEOUT
1094+
@patch("deepmd.pt.train.validation.FullValidator.evaluate_all_systems")
1095+
def test_ema_full_validation_ignored_without_full_validation(
1096+
self, mocked_eval
1097+
) -> None:
1098+
config = deepcopy(self.config)
1099+
config["training"]["enable_ema"] = False
1100+
config["validating"]["full_validation"] = False
1101+
config["validating"]["ema_full_validation"] = True
1102+
trainer = get_trainer(config)
1103+
trainer.run()
1104+
1105+
mocked_eval.assert_not_called()
1106+
self.assertFalse(Path("val.log").exists())
1107+
self.assertFalse(Path("val_ema.log").exists())
1108+
self.assertIsNone(trainer.model_ema)
1109+
self.assertIsNone(trainer.ema_full_validator)
1110+
10881111

10891112
if __name__ == "__main__":
10901113
unittest.main()

0 commit comments

Comments (0)