|
4 | 4 | import unittest |
5 | 5 | from unittest.mock import ( |
6 | 6 | MagicMock, |
| 7 | + patch, |
7 | 8 | ) |
8 | 9 |
|
9 | 10 | import torch |
10 | 11 |
|
11 | 12 |
|
12 | | -class TestGetDescriptorType(unittest.TestCase): |
13 | | - """Test get_descriptor_type helper function.""" |
14 | | - |
15 | | - @staticmethod |
16 | | - def get_descriptor_type(model): |
17 | | - """Replicate the logic from training.py for testing.""" |
18 | | - # Standard models have get_descriptor method |
19 | | - if hasattr(model, "get_descriptor"): |
20 | | - descriptor = model.get_descriptor() |
21 | | - serialized = descriptor.serialize() |
22 | | - if isinstance(serialized, dict) and "type" in serialized: |
23 | | - return serialized["type"].upper() |
24 | | - # ZBL models: descriptor is in atomic_model.models[0] |
25 | | - if hasattr(model, "atomic_model") and hasattr(model.atomic_model, "models"): |
26 | | - models = model.atomic_model.models |
27 | | - if models: # Check non-empty |
28 | | - dp_model = models[0] |
29 | | - if hasattr(dp_model, "descriptor"): |
30 | | - serialized = dp_model.descriptor.serialize() |
31 | | - if isinstance(serialized, dict) and "type" in serialized: |
32 | | - return serialized["type"].upper() + " (with ZBL)" |
33 | | - return "UNKNOWN" |
34 | | - |
35 | | - def test_standard_model(self): |
36 | | - """Test descriptor type detection for standard models.""" |
37 | | - mock_descriptor = MagicMock() |
38 | | - mock_descriptor.serialize.return_value = {"type": "se_e2_a"} |
39 | | - |
40 | | - mock_model = MagicMock() |
41 | | - mock_model.get_descriptor.return_value = mock_descriptor |
| 13 | +class TestLogModelSummary(unittest.TestCase): |
| 14 | + """Test _log_model_summary method behavior.""" |
42 | 15 |
|
43 | | - result = self.get_descriptor_type(mock_model) |
44 | | - self.assertEqual(result, "SE_E2_A") |
| 16 | + def _create_mock_trainer(self, multi_task: bool = False): |
| 17 | + """Create a mock Trainer instance for testing.""" |
| 18 | + trainer = MagicMock() |
| 19 | + trainer.multi_task = multi_task |
| 20 | + trainer.rank = 0 |
| 21 | + return trainer |
45 | 22 |
|
46 | | - def test_zbl_model(self): |
47 | | - """Test descriptor type detection for ZBL models.""" |
| 23 | + def _create_mock_model_with_descriptor(self, desc_type: str): |
| 24 | + """Create a mock model with get_descriptor method.""" |
48 | 25 | mock_descriptor = MagicMock() |
49 | | - mock_descriptor.serialize.return_value = {"type": "dpa1"} |
50 | | - |
51 | | - mock_dp_model = MagicMock() |
52 | | - mock_dp_model.descriptor = mock_descriptor |
53 | | - |
54 | | - mock_atomic_model = MagicMock() |
55 | | - mock_atomic_model.models = [mock_dp_model] |
| 26 | + mock_descriptor.serialize.return_value = {"type": desc_type} |
56 | 27 |
|
57 | | - mock_model = MagicMock(spec=[]) # No get_descriptor |
58 | | - mock_model.atomic_model = mock_atomic_model |
59 | | - |
60 | | - result = self.get_descriptor_type(mock_model) |
61 | | - self.assertEqual(result, "DPA1 (with ZBL)") |
62 | | - |
63 | | - def test_empty_models_list(self): |
64 | | - """Test handling of empty models list in ZBL model.""" |
65 | | - mock_atomic_model = MagicMock() |
66 | | - mock_atomic_model.models = [] |
| 28 | + mock_model = MagicMock(spec=torch.nn.Module) |
| 29 | + mock_model.get_descriptor.return_value = mock_descriptor |
| 30 | + mock_model.parameters.return_value = iter( |
| 31 | + [torch.nn.Parameter(torch.randn(10, 5))] |
| 32 | + ) |
| 33 | + return mock_model |
| 34 | + |
| 35 | + def _create_mock_zbl_model(self, desc_type: str): |
| 36 | + """Create a mock ZBL model using serialize() API.""" |
| 37 | + mock_model = MagicMock(spec=torch.nn.Module) |
| 38 | + # Remove get_descriptor to simulate ZBL model |
| 39 | + del mock_model.get_descriptor |
| 40 | + mock_model.serialize.return_value = { |
| 41 | + "type": "zbl", |
| 42 | + "models": [ |
| 43 | + {"descriptor": {"type": desc_type}}, |
| 44 | + {"type": "pairtab"}, |
| 45 | + ], |
| 46 | + } |
| 47 | + mock_model.parameters.return_value = iter( |
| 48 | + [torch.nn.Parameter(torch.randn(10, 5))] |
| 49 | + ) |
| 50 | + return mock_model |
| 51 | + |
| 52 | + @patch("deepmd.pt.train.training.log") |
| 53 | + def test_standard_model_log_output(self, mock_log): |
| 54 | + """Test log output for standard models.""" |
| 55 | + from deepmd.pt.train.training import ( |
| 56 | + Trainer, |
| 57 | + ) |
| 58 | + |
| 59 | + trainer = self._create_mock_trainer(multi_task=False) |
| 60 | + trainer.model = self._create_mock_model_with_descriptor("se_e2_a") |
| 61 | + |
| 62 | + # Call the actual method |
| 63 | + Trainer._log_model_summary(trainer) |
| 64 | + |
| 65 | + # Verify log.info was called with expected descriptor type |
| 66 | + calls = [str(call) for call in mock_log.info.call_args_list] |
| 67 | + self.assertTrue(any("SE_E2_A" in call for call in calls)) |
| 68 | + self.assertTrue(any("Model Params" in call for call in calls)) |
| 69 | + |
| 70 | + @patch("deepmd.pt.train.training.log") |
| 71 | + def test_zbl_model_log_output(self, mock_log): |
| 72 | + """Test log output for ZBL models.""" |
| 73 | + from deepmd.pt.train.training import ( |
| 74 | + Trainer, |
| 75 | + ) |
| 76 | + |
| 77 | + trainer = self._create_mock_trainer(multi_task=False) |
| 78 | + trainer.model = self._create_mock_zbl_model("dpa1") |
| 79 | + |
| 80 | + # Call the actual method |
| 81 | + Trainer._log_model_summary(trainer) |
| 82 | + |
| 83 | + # Verify log.info was called with expected descriptor type |
| 84 | + calls = [str(call) for call in mock_log.info.call_args_list] |
| 85 | + self.assertTrue(any("DPA1 (with ZBL)" in call for call in calls)) |
| 86 | + |
| 87 | + @patch("deepmd.pt.train.training.log") |
| 88 | + def test_multi_task_log_output(self, mock_log): |
| 89 | + """Test log output for multi-task models.""" |
| 90 | + from deepmd.pt.train.training import ( |
| 91 | + Trainer, |
| 92 | + ) |
| 93 | + |
| 94 | + trainer = self._create_mock_trainer(multi_task=True) |
| 95 | + trainer.model_keys = ["task1", "task2"] |
| 96 | + trainer.model = { |
| 97 | + "task1": self._create_mock_model_with_descriptor("dpa2"), |
| 98 | + "task2": self._create_mock_model_with_descriptor("se_atten"), |
| 99 | + } |
| 100 | + |
| 101 | + # Call the actual method |
| 102 | + Trainer._log_model_summary(trainer) |
| 103 | + |
| 104 | + # Verify log.info was called for each task |
| 105 | + calls = [str(call) for call in mock_log.info.call_args_list] |
| 106 | + self.assertTrue(any("task1" in call for call in calls)) |
| 107 | + self.assertTrue(any("task2" in call for call in calls)) |
| 108 | + self.assertTrue(any("DPA2" in call for call in calls)) |
| 109 | + self.assertTrue(any("SE_ATTEN" in call for call in calls)) |
| 110 | + |
| 111 | + @patch("deepmd.pt.train.training.log") |
| 112 | + def test_unknown_model_structure(self, mock_log): |
| 113 | + """Test handling of unknown model structure.""" |
| 114 | + from deepmd.pt.train.training import ( |
| 115 | + Trainer, |
| 116 | + ) |
| 117 | + |
| 118 | + trainer = self._create_mock_trainer(multi_task=False) |
| 119 | + # Model without get_descriptor and without serialize returning valid type |
| 120 | + mock_model = MagicMock(spec=torch.nn.Module) |
| 121 | + del mock_model.get_descriptor |
| 122 | + mock_model.serialize.return_value = {"other_key": "value"} |
| 123 | + mock_model.parameters.return_value = iter([]) |
| 124 | + trainer.model = mock_model |
| 125 | + |
| 126 | + # Call the actual method |
| 127 | + Trainer._log_model_summary(trainer) |
| 128 | + |
| 129 | + # Verify "UNKNOWN" appears in output |
| 130 | + calls = [str(call) for call in mock_log.info.call_args_list] |
| 131 | + self.assertTrue(any("UNKNOWN" in call for call in calls)) |
| 132 | + |
| 133 | + @patch("deepmd.pt.train.training.log") |
| 134 | + def test_none_descriptor(self, mock_log): |
| 135 | + """Test handling when get_descriptor returns None.""" |
| 136 | + from deepmd.pt.train.training import ( |
| 137 | + Trainer, |
| 138 | + ) |
| 139 | + |
| 140 | + trainer = self._create_mock_trainer(multi_task=False) |
| 141 | + mock_model = MagicMock(spec=torch.nn.Module) |
| 142 | + mock_model.get_descriptor.return_value = None |
| 143 | + mock_model.serialize.return_value = {"other_key": "value"} |
| 144 | + mock_model.parameters.return_value = iter([]) |
| 145 | + trainer.model = mock_model |
| 146 | + |
| 147 | + # Call the actual method - should not raise AttributeError |
| 148 | + Trainer._log_model_summary(trainer) |
| 149 | + |
| 150 | + # Verify "UNKNOWN" appears in output |
| 151 | + calls = [str(call) for call in mock_log.info.call_args_list] |
| 152 | + self.assertTrue(any("UNKNOWN" in call for call in calls)) |
67 | 153 |
|
68 | | - mock_model = MagicMock(spec=[]) |
69 | | - mock_model.atomic_model = mock_atomic_model |
70 | 154 |
|
71 | | - result = self.get_descriptor_type(mock_model) |
72 | | - self.assertEqual(result, "UNKNOWN") |
| 155 | +class TestCountParameters(unittest.TestCase): |
| 156 | + """Test parameter counting behavior through _log_model_summary.""" |
73 | 157 |
|
74 | | - def test_missing_type_key(self): |
75 | | - """Test handling of serialize() without 'type' key.""" |
76 | | - mock_descriptor = MagicMock() |
77 | | - mock_descriptor.serialize.return_value = {"other_key": "value"} |
| 158 | + @patch("deepmd.pt.train.training.log") |
| 159 | + def test_parameter_count_in_log(self, mock_log): |
| 160 | + """Test that parameter count is correctly logged.""" |
| 161 | + from deepmd.pt.train.training import ( |
| 162 | + Trainer, |
| 163 | + ) |
78 | 164 |
|
79 | | - mock_model = MagicMock() |
80 | | - mock_model.get_descriptor.return_value = mock_descriptor |
| 165 | + trainer = MagicMock() |
| 166 | + trainer.multi_task = False |
| 167 | + trainer.rank = 0 |
81 | 168 |
|
82 | | - result = self.get_descriptor_type(mock_model) |
83 | | - self.assertEqual(result, "UNKNOWN") |
| 169 | + # Create model with known parameter count |
| 170 | + real_model = torch.nn.Linear(10, 5).to("cpu") # 10*5 + 5 = 55 parameters |
84 | 171 |
|
85 | | - def test_serialize_returns_non_dict(self): |
86 | | - """Test handling of serialize() returning non-dict.""" |
| 172 | + # Add mock methods |
87 | 173 | mock_descriptor = MagicMock() |
88 | | - mock_descriptor.serialize.return_value = "not_a_dict" |
89 | | - |
90 | | - mock_model = MagicMock() |
91 | | - mock_model.get_descriptor.return_value = mock_descriptor |
92 | | - |
93 | | - result = self.get_descriptor_type(mock_model) |
94 | | - self.assertEqual(result, "UNKNOWN") |
| 174 | + mock_descriptor.serialize.return_value = {"type": "test"} |
| 175 | + real_model.get_descriptor = MagicMock(return_value=mock_descriptor) |
95 | 176 |
|
96 | | - def test_unknown_model_structure(self): |
97 | | - """Test handling of unknown model structure.""" |
98 | | - mock_model = MagicMock(spec=[]) # No get_descriptor, no atomic_model |
99 | | - result = self.get_descriptor_type(mock_model) |
100 | | - self.assertEqual(result, "UNKNOWN") |
| 177 | + trainer.model = real_model |
101 | 178 |
|
| 179 | + # Call the actual method |
| 180 | + Trainer._log_model_summary(trainer) |
102 | 181 |
|
103 | | -class TestCountParameters(unittest.TestCase): |
104 | | - """Test count_parameters helper function.""" |
105 | | - |
106 | | - @staticmethod |
107 | | - def count_parameters(model): |
108 | | - """Replicate the logic from training.py for testing.""" |
109 | | - return sum(p.numel() for p in model.parameters() if p.requires_grad) |
110 | | - |
111 | | - def test_all_trainable(self): |
112 | | - """Test counting when all parameters are trainable.""" |
113 | | - with torch.device("cpu"): |
114 | | - model = torch.nn.Linear(10, 5) # 10*5 + 5 = 55 parameters |
115 | | - result = self.count_parameters(model) |
116 | | - self.assertEqual(result, 55) |
117 | | - |
118 | | - def test_mixed_trainable(self): |
119 | | - """Test counting with some frozen parameters.""" |
120 | | - with torch.device("cpu"): |
121 | | - model = torch.nn.Sequential( |
122 | | - torch.nn.Linear(10, 5), # 55 params |
123 | | - torch.nn.Linear(5, 3), # 18 params |
124 | | - ) |
125 | | - # Freeze first layer |
126 | | - for param in model[0].parameters(): |
127 | | - param.requires_grad = False |
128 | | - |
129 | | - result = self.count_parameters(model) |
130 | | - self.assertEqual(result, 18) # Only second layer |
131 | | - |
132 | | - def test_all_frozen(self): |
133 | | - """Test counting when all parameters are frozen.""" |
134 | | - with torch.device("cpu"): |
135 | | - model = torch.nn.Linear(10, 5) |
136 | | - for param in model.parameters(): |
137 | | - param.requires_grad = False |
138 | | - |
139 | | - result = self.count_parameters(model) |
140 | | - self.assertEqual(result, 0) |
| 182 | + # Verify parameter count is logged (55 params = 0.000055 M) |
| 183 | + calls = [str(call) for call in mock_log.info.call_args_list] |
| 184 | + self.assertTrue(any("0.000 M" in call for call in calls)) |
141 | 185 |
|
142 | 186 |
|
143 | 187 | if __name__ == "__main__": |
|
0 commit comments