@@ -707,11 +707,9 @@ class TestModelChangeOutBiasFittingStat(unittest.TestCase):
707707 """
708708
709709 def test_fitting_stat_consistency (self ) -> None :
710+ import deepmd .pt .train .training as training_module
710711 from deepmd .pt .model .model import get_model as get_model_pt
711712 from deepmd .pt .model .model .ener_model import EnergyModel as EnergyModelPT
712- from deepmd .pt .train .training import (
713- model_change_out_bias ,
714- )
715713 from deepmd .pt .utils .utils import to_numpy_array as torch_to_numpy
716714 from deepmd .pt .utils .utils import to_torch_tensor as numpy_to_torch
717715 from deepmd .utils .argcheck import model_args as model_args_fn
@@ -785,7 +783,7 @@ def test_fitting_stat_consistency(self) -> None:
785783
786784 # Model B: use the NEW code path via model_change_out_bias
787785 sample_func = lambda : merged # noqa: E731
788- model_change_out_bias (model_b , sample_func , "set-by-statistic" )
786+ training_module . model_change_out_bias (model_b , sample_func , "set-by-statistic" )
789787
790788 # Compare out_bias
791789 bias_a = torch_to_numpy (model_a .get_out_bias ())
@@ -1019,6 +1017,42 @@ def test_ema_checkpoint_cleanup_removes_future_steps(self) -> None:
10191017 self .assertFalse (Path (f"{ ema_prefix } -999.pt" ).exists ())
10201018 self .assertTrue (Path (f"{ ema_prefix } -1.pt" ).exists ())
10211019
1020+ @TRAINING_TEST_TIMEOUT
1021+ @patch ("deepmd.pt.train.training.model_change_out_bias" )
1022+ def test_ema_checkpoint_keeps_changed_out_bias (self , mocked_change_out_bias ) -> None :
1023+ def change_out_bias (model , sample_func , _bias_adjust_mode ):
1024+ model .set_out_bias (model .get_out_bias () + 1.0 )
1025+ return model
1026+
1027+ mocked_change_out_bias .side_effect = change_out_bias
1028+ config = deepcopy (self .config )
1029+ config ["training" ]["numb_steps" ] = 1
1030+ config ["training" ]["change_bias_after_training" ] = True
1031+ trainer = get_trainer (config )
1032+ trainer .run ()
1033+
1034+ regular_checkpoint = torch .load (
1035+ trainer .latest_model , map_location = "cpu" , weights_only = True
1036+ )
1037+ ema_checkpoint = torch .load (
1038+ trainer .latest_ema_model , map_location = "cpu" , weights_only = True
1039+ )
1040+ regular_out_bias = {
1041+ key : value
1042+ for key , value in regular_checkpoint ["model" ].items ()
1043+ if key .endswith ("out_bias" )
1044+ }
1045+ ema_out_bias = {
1046+ key : value
1047+ for key , value in ema_checkpoint ["model" ].items ()
1048+ if key .endswith ("out_bias" )
1049+ }
1050+
1051+ self .assertTrue (regular_out_bias )
1052+ self .assertEqual (regular_out_bias .keys (), ema_out_bias .keys ())
1053+ for key , regular_value in regular_out_bias .items ():
1054+ torch .testing .assert_close (regular_value , ema_out_bias [key ])
1055+
10221056 def test_ema_rejects_zero_stage_2_during_normalization (self ) -> None :
10231057 config = deepcopy (self .config )
10241058 config ["training" ]["zero_stage" ] = 2
0 commit comments