Skip to content

Commit b0b9c37

Browse files
committed
fixup
1 parent 84b374b commit b0b9c37

1 file changed

Lines changed: 38 additions & 4 deletions

File tree

source/tests/pt/test_training.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -707,11 +707,9 @@ class TestModelChangeOutBiasFittingStat(unittest.TestCase):
707707
"""
708708

709709
def test_fitting_stat_consistency(self) -> None:
710+
import deepmd.pt.train.training as training_module
710711
from deepmd.pt.model.model import get_model as get_model_pt
711712
from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT
712-
from deepmd.pt.train.training import (
713-
model_change_out_bias,
714-
)
715713
from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy
716714
from deepmd.pt.utils.utils import to_torch_tensor as numpy_to_torch
717715
from deepmd.utils.argcheck import model_args as model_args_fn
@@ -785,7 +783,7 @@ def test_fitting_stat_consistency(self) -> None:
785783

786784
# Model B: use the NEW code path via model_change_out_bias
787785
sample_func = lambda: merged # noqa: E731
788-
model_change_out_bias(model_b, sample_func, "set-by-statistic")
786+
training_module.model_change_out_bias(model_b, sample_func, "set-by-statistic")
789787

790788
# Compare out_bias
791789
bias_a = torch_to_numpy(model_a.get_out_bias())
@@ -1019,6 +1017,42 @@ def test_ema_checkpoint_cleanup_removes_future_steps(self) -> None:
10191017
self.assertFalse(Path(f"{ema_prefix}-999.pt").exists())
10201018
self.assertTrue(Path(f"{ema_prefix}-1.pt").exists())
10211019

1020+
@TRAINING_TEST_TIMEOUT
1021+
@patch("deepmd.pt.train.training.model_change_out_bias")
1022+
def test_ema_checkpoint_keeps_changed_out_bias(self, mocked_change_out_bias) -> None:
1023+
def change_out_bias(model, sample_func, _bias_adjust_mode):
1024+
model.set_out_bias(model.get_out_bias() + 1.0)
1025+
return model
1026+
1027+
mocked_change_out_bias.side_effect = change_out_bias
1028+
config = deepcopy(self.config)
1029+
config["training"]["numb_steps"] = 1
1030+
config["training"]["change_bias_after_training"] = True
1031+
trainer = get_trainer(config)
1032+
trainer.run()
1033+
1034+
regular_checkpoint = torch.load(
1035+
trainer.latest_model, map_location="cpu", weights_only=True
1036+
)
1037+
ema_checkpoint = torch.load(
1038+
trainer.latest_ema_model, map_location="cpu", weights_only=True
1039+
)
1040+
regular_out_bias = {
1041+
key: value
1042+
for key, value in regular_checkpoint["model"].items()
1043+
if key.endswith("out_bias")
1044+
}
1045+
ema_out_bias = {
1046+
key: value
1047+
for key, value in ema_checkpoint["model"].items()
1048+
if key.endswith("out_bias")
1049+
}
1050+
1051+
self.assertTrue(regular_out_bias)
1052+
self.assertEqual(regular_out_bias.keys(), ema_out_bias.keys())
1053+
for key, regular_value in regular_out_bias.items():
1054+
torch.testing.assert_close(regular_value, ema_out_bias[key])
1055+
10221056
def test_ema_rejects_zero_stage_2_during_normalization(self) -> None:
10231057
config = deepcopy(self.config)
10241058
config["training"]["zero_stage"] = 2

0 commit comments

Comments
 (0)