|
4 | 4 | from pathlib import ( |
5 | 5 | Path, |
6 | 6 | ) |
| 7 | +from unittest.mock import ( |
| 8 | + MagicMock, |
| 9 | + patch, |
| 10 | +) |
7 | 11 |
|
8 | 12 | from deepmd.tf.entrypoints.change_bias import ( |
9 | 13 | change_bias, |
@@ -89,6 +93,129 @@ def test_change_bias_user_defined_not_implemented(self): |
89 | 93 | "User-defined bias setting is not yet implemented", str(cm.exception) |
90 | 94 | ) |
91 | 95 |
|
| 96 | + def test_change_bias_successful_execution(self): |
| 97 | + """Test successful bias changing execution path.""" |
| 98 | + # Create fake checkpoint directory with required files |
| 99 | + fake_checkpoint_dir = self.temp_path / "checkpoint" |
| 100 | + fake_checkpoint_dir.mkdir() |
| 101 | + (fake_checkpoint_dir / "checkpoint").write_text("fake checkpoint content") |
| 102 | + (fake_checkpoint_dir / "input.json").write_text('{"model": {}}') |
| 103 | + |
| 104 | + # Create fake data system |
| 105 | + fake_data_dir = self.temp_path / "data_system" |
| 106 | + fake_data_dir.mkdir() |
| 107 | + fake_set_dir = fake_data_dir / "set.000" |
| 108 | + fake_set_dir.mkdir() |
| 109 | + |
| 110 | + # Import the module properly |
| 111 | + import sys |
| 112 | + |
| 113 | + change_bias_module = sys.modules["deepmd.tf.entrypoints.change_bias"] |
| 114 | + |
| 115 | + with ( |
| 116 | + patch.object( |
| 117 | + change_bias_module, "expand_sys_str", return_value=[str(fake_data_dir)] |
| 118 | + ), |
| 119 | + patch.object(change_bias_module, "j_loader", return_value={"model": {}}), |
| 120 | + patch.object( |
| 121 | + change_bias_module, "update_deepmd_input", return_value={"model": {}} |
| 122 | + ), |
| 123 | + patch.object(change_bias_module, "normalize", return_value={"model": {}}), |
| 124 | + patch.object(change_bias_module, "DeepmdDataSystem") as mock_data_system, |
| 125 | + patch.object(change_bias_module, "DPTrainer") as mock_trainer_class, |
| 126 | + patch.object(change_bias_module, "shutil"), |
| 127 | + ): |
| 128 | + # Mock the data system |
| 129 | + mock_data_instance = MagicMock() |
| 130 | + mock_data_instance.get_type_map.return_value = ["H", "O"] |
| 131 | + mock_data_system.return_value = mock_data_instance |
| 132 | + |
| 133 | + # Mock the trainer |
| 134 | + mock_trainer_instance = MagicMock() |
| 135 | + mock_model = MagicMock() |
| 136 | + mock_model.get_type_map.return_value = ["H", "O"] |
| 137 | + mock_trainer_instance.model = mock_model |
| 138 | + mock_trainer_instance._change_energy_bias = MagicMock() |
| 139 | + mock_trainer_instance.save_checkpoint = MagicMock() |
| 140 | + mock_trainer_class.return_value = mock_trainer_instance |
| 141 | + |
| 142 | + # Call change_bias function |
| 143 | + change_bias( |
| 144 | + INPUT=str(fake_checkpoint_dir), |
| 145 | + mode="change", |
| 146 | + system=str(fake_data_dir), |
| 147 | + output=str(self.temp_path / "output"), |
| 148 | + ) |
| 149 | + |
| 150 | + # Verify that the trainer's change_energy_bias was called |
| 151 | + mock_trainer_instance._change_energy_bias.assert_called_once() |
| 152 | + |
| 153 | + def test_change_bias_with_data_type_map(self): |
| 154 | + """Test bias changing when data system has its own type_map.""" |
| 155 | + # Create fake checkpoint directory with required files |
| 156 | + fake_checkpoint_dir = self.temp_path / "checkpoint" |
| 157 | + fake_checkpoint_dir.mkdir() |
| 158 | + (fake_checkpoint_dir / "checkpoint").write_text("fake checkpoint content") |
| 159 | + (fake_checkpoint_dir / "input.json").write_text('{"model": {}}') |
| 160 | + |
| 161 | + # Create fake data system |
| 162 | + fake_data_dir = self.temp_path / "data_system" |
| 163 | + fake_data_dir.mkdir() |
| 164 | + fake_set_dir = fake_data_dir / "set.000" |
| 165 | + fake_set_dir.mkdir() |
| 166 | + |
| 167 | + # Import the module properly |
| 168 | + import sys |
| 169 | + |
| 170 | + change_bias_module = sys.modules["deepmd.tf.entrypoints.change_bias"] |
| 171 | + |
| 172 | + with ( |
| 173 | + patch.object( |
| 174 | + change_bias_module, "expand_sys_str", return_value=[str(fake_data_dir)] |
| 175 | + ), |
| 176 | + patch.object(change_bias_module, "j_loader", return_value={"model": {}}), |
| 177 | + patch.object( |
| 178 | + change_bias_module, "update_deepmd_input", return_value={"model": {}} |
| 179 | + ), |
| 180 | + patch.object(change_bias_module, "normalize", return_value={"model": {}}), |
| 181 | + patch.object(change_bias_module, "DeepmdDataSystem") as mock_data_system, |
| 182 | + patch.object(change_bias_module, "DPTrainer") as mock_trainer_class, |
| 183 | + patch.object(change_bias_module, "shutil"), |
| 184 | + ): |
| 185 | + # Mock the data system with type_map |
| 186 | + mock_data_instance = MagicMock() |
| 187 | + mock_data_instance.get_type_map.return_value = [ |
| 188 | + "C", |
| 189 | + "N", |
| 190 | + "O", |
| 191 | + ] # Data has type_map |
| 192 | + mock_data_system.return_value = mock_data_instance |
| 193 | + |
| 194 | + # Mock the trainer |
| 195 | + mock_trainer_instance = MagicMock() |
| 196 | + mock_model = MagicMock() |
| 197 | + mock_model.get_type_map.return_value = [ |
| 198 | + "H", |
| 199 | + "O", |
| 200 | + ] # Model has different type_map |
| 201 | + mock_trainer_instance.model = mock_model |
| 202 | + mock_trainer_instance._change_energy_bias = MagicMock() |
| 203 | + mock_trainer_instance.save_checkpoint = MagicMock() |
| 204 | + mock_trainer_class.return_value = mock_trainer_instance |
| 205 | + |
| 206 | + # Call change_bias function |
| 207 | + change_bias( |
| 208 | + INPUT=str(fake_checkpoint_dir), |
| 209 | + mode="change", |
| 210 | + system=str(fake_data_dir), |
| 211 | + ) |
| 212 | + |
| 213 | + # Verify that data's type_map was used (not model's) |
| 214 | + mock_trainer_instance._change_energy_bias.assert_called_once() |
| 215 | + args, kwargs = mock_trainer_instance._change_energy_bias.call_args |
| 216 | + # The third argument should be the type_map from data |
| 217 | + self.assertEqual(args[2], ["C", "N", "O"]) |
| 218 | + |
92 | 219 |
|
93 | 220 | if __name__ == "__main__": |
94 | 221 | unittest.main() |
0 commit comments