@@ -1063,17 +1063,22 @@ def test_restart_restores_ema_state(self) -> None:
10631063 @patch ("deepmd.pt.train.validation.FullValidator.evaluate_all_systems" )
10641064 def test_ema_full_validation_writes_separate_outputs (self , mocked_eval ) -> None :
10651065 mocked_eval .side_effect = [
1066+ {"mae_e_per_atom" : 10.0 },
10661067 {"mae_e_per_atom" : 1.0 },
1068+ {"mae_e_per_atom" : 10.0 },
10671069 {"mae_e_per_atom" : 0.5 },
1070+ {"mae_e_per_atom" : 10.0 },
10681071 {"mae_e_per_atom" : 0.75 },
1072+ {"mae_e_per_atom" : 10.0 },
10691073 {"mae_e_per_atom" : 0.25 },
10701074 ]
10711075 config = deepcopy (self .config )
1076+ config ["validating" ]["full_validation" ] = True
10721077 config ["validating" ]["ema_full_validation" ] = True
10731078 trainer = get_trainer (config )
10741079 trainer .run ()
10751080
1076- self .assertFalse (Path ("val.log" ).exists ())
1081+ self .assertTrue (Path ("val.log" ).exists ())
10771082 self .assertTrue (Path ("val_ema.log" ).exists ())
10781083 self .assertTrue (Path ("best_ema.ckpt-4.t-1.pt" ).exists ())
10791084 self .assertFalse (Path ("best_ema.ckpt-1.t-1.pt" ).exists ())
@@ -1085,6 +1090,24 @@ def test_ema_full_validation_writes_separate_outputs(self, mocked_eval) -> None:
10851090 [{"metric" : 0.25 , "step" : 4 }],
10861091 )
10871092
1093+ @TRAINING_TEST_TIMEOUT
1094+ @patch ("deepmd.pt.train.validation.FullValidator.evaluate_all_systems" )
1095+ def test_ema_full_validation_ignored_without_full_validation (
1096+ self , mocked_eval
1097+ ) -> None :
1098+ config = deepcopy (self .config )
1099+ config ["training" ]["enable_ema" ] = False
1100+ config ["validating" ]["full_validation" ] = False
1101+ config ["validating" ]["ema_full_validation" ] = True
1102+ trainer = get_trainer (config )
1103+ trainer .run ()
1104+
1105+ mocked_eval .assert_not_called ()
1106+ self .assertFalse (Path ("val.log" ).exists ())
1107+ self .assertFalse (Path ("val_ema.log" ).exists ())
1108+ self .assertIsNone (trainer .model_ema )
1109+ self .assertIsNone (trainer .ema_full_validator )
1110+
10881111
10891112if __name__ == "__main__" :
10901113 unittest .main ()
0 commit comments