Skip to content

Commit 6a60acd

Browse files
Copilotnjzjz
andcommitted
Addressing PR comments
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 272e087 commit 6a60acd

2 files changed

Lines changed: 127 additions & 0 deletions

File tree

.coverage

108 KB
Binary file not shown.

source/tests/tf/test_change_bias.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from pathlib import (
55
Path,
66
)
7+
from unittest.mock import (
8+
MagicMock,
9+
patch,
10+
)
711

812
from deepmd.tf.entrypoints.change_bias import (
913
change_bias,
@@ -89,6 +93,129 @@ def test_change_bias_user_defined_not_implemented(self):
8993
"User-defined bias setting is not yet implemented", str(cm.exception)
9094
)
9195

96+
def test_change_bias_successful_execution(self):
97+
"""Test successful bias changing execution path."""
98+
# Create fake checkpoint directory with required files
99+
fake_checkpoint_dir = self.temp_path / "checkpoint"
100+
fake_checkpoint_dir.mkdir()
101+
(fake_checkpoint_dir / "checkpoint").write_text("fake checkpoint content")
102+
(fake_checkpoint_dir / "input.json").write_text('{"model": {}}')
103+
104+
# Create fake data system
105+
fake_data_dir = self.temp_path / "data_system"
106+
fake_data_dir.mkdir()
107+
fake_set_dir = fake_data_dir / "set.000"
108+
fake_set_dir.mkdir()
109+
110+
# Import the module properly
111+
import sys
112+
113+
change_bias_module = sys.modules["deepmd.tf.entrypoints.change_bias"]
114+
115+
with (
116+
patch.object(
117+
change_bias_module, "expand_sys_str", return_value=[str(fake_data_dir)]
118+
),
119+
patch.object(change_bias_module, "j_loader", return_value={"model": {}}),
120+
patch.object(
121+
change_bias_module, "update_deepmd_input", return_value={"model": {}}
122+
),
123+
patch.object(change_bias_module, "normalize", return_value={"model": {}}),
124+
patch.object(change_bias_module, "DeepmdDataSystem") as mock_data_system,
125+
patch.object(change_bias_module, "DPTrainer") as mock_trainer_class,
126+
patch.object(change_bias_module, "shutil"),
127+
):
128+
# Mock the data system
129+
mock_data_instance = MagicMock()
130+
mock_data_instance.get_type_map.return_value = ["H", "O"]
131+
mock_data_system.return_value = mock_data_instance
132+
133+
# Mock the trainer
134+
mock_trainer_instance = MagicMock()
135+
mock_model = MagicMock()
136+
mock_model.get_type_map.return_value = ["H", "O"]
137+
mock_trainer_instance.model = mock_model
138+
mock_trainer_instance._change_energy_bias = MagicMock()
139+
mock_trainer_instance.save_checkpoint = MagicMock()
140+
mock_trainer_class.return_value = mock_trainer_instance
141+
142+
# Call change_bias function
143+
change_bias(
144+
INPUT=str(fake_checkpoint_dir),
145+
mode="change",
146+
system=str(fake_data_dir),
147+
output=str(self.temp_path / "output"),
148+
)
149+
150+
# Verify that the trainer's change_energy_bias was called
151+
mock_trainer_instance._change_energy_bias.assert_called_once()
152+
153+
def test_change_bias_with_data_type_map(self):
154+
"""Test bias changing when data system has its own type_map."""
155+
# Create fake checkpoint directory with required files
156+
fake_checkpoint_dir = self.temp_path / "checkpoint"
157+
fake_checkpoint_dir.mkdir()
158+
(fake_checkpoint_dir / "checkpoint").write_text("fake checkpoint content")
159+
(fake_checkpoint_dir / "input.json").write_text('{"model": {}}')
160+
161+
# Create fake data system
162+
fake_data_dir = self.temp_path / "data_system"
163+
fake_data_dir.mkdir()
164+
fake_set_dir = fake_data_dir / "set.000"
165+
fake_set_dir.mkdir()
166+
167+
# Import the module properly
168+
import sys
169+
170+
change_bias_module = sys.modules["deepmd.tf.entrypoints.change_bias"]
171+
172+
with (
173+
patch.object(
174+
change_bias_module, "expand_sys_str", return_value=[str(fake_data_dir)]
175+
),
176+
patch.object(change_bias_module, "j_loader", return_value={"model": {}}),
177+
patch.object(
178+
change_bias_module, "update_deepmd_input", return_value={"model": {}}
179+
),
180+
patch.object(change_bias_module, "normalize", return_value={"model": {}}),
181+
patch.object(change_bias_module, "DeepmdDataSystem") as mock_data_system,
182+
patch.object(change_bias_module, "DPTrainer") as mock_trainer_class,
183+
patch.object(change_bias_module, "shutil"),
184+
):
185+
# Mock the data system with type_map
186+
mock_data_instance = MagicMock()
187+
mock_data_instance.get_type_map.return_value = [
188+
"C",
189+
"N",
190+
"O",
191+
] # Data has type_map
192+
mock_data_system.return_value = mock_data_instance
193+
194+
# Mock the trainer
195+
mock_trainer_instance = MagicMock()
196+
mock_model = MagicMock()
197+
mock_model.get_type_map.return_value = [
198+
"H",
199+
"O",
200+
] # Model has different type_map
201+
mock_trainer_instance.model = mock_model
202+
mock_trainer_instance._change_energy_bias = MagicMock()
203+
mock_trainer_instance.save_checkpoint = MagicMock()
204+
mock_trainer_class.return_value = mock_trainer_instance
205+
206+
# Call change_bias function
207+
change_bias(
208+
INPUT=str(fake_checkpoint_dir),
209+
mode="change",
210+
system=str(fake_data_dir),
211+
)
212+
213+
# Verify that data's type_map was used (not model's)
214+
mock_trainer_instance._change_energy_bias.assert_called_once()
215+
args, kwargs = mock_trainer_instance._change_energy_bias.call_args
216+
# The third argument should be the type_map from data
217+
self.assertEqual(args[2], ["C", "N", "O"])
218+
92219

93220
if __name__ == "__main__":
94221
unittest.main()

0 commit comments

Comments
 (0)