Skip to content

Commit d6aeacc

Browse files
committed
adopt
1 parent c608729 commit d6aeacc

2 files changed

Lines changed: 179 additions & 129 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -750,26 +750,32 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
750750
def _log_model_summary(self) -> None:
751751
"""Log model summary information including descriptor type and parameter count."""
752752

753-
def get_descriptor_type(model: Any) -> str:
753+
def get_descriptor_type(model: torch.nn.Module) -> str:
754754
"""Get the descriptor type name from model."""
755755
# Standard models have get_descriptor method
756756
if hasattr(model, "get_descriptor"):
757757
descriptor = model.get_descriptor()
758-
serialized = descriptor.serialize()
759-
if isinstance(serialized, dict) and "type" in serialized:
760-
return serialized["type"].upper()
761-
# ZBL models: descriptor is in atomic_model.models[0]
762-
if hasattr(model, "atomic_model") and hasattr(model.atomic_model, "models"):
763-
models = model.atomic_model.models
764-
if models: # Check non-empty
765-
dp_model = models[0]
766-
if hasattr(dp_model, "descriptor"):
767-
serialized = dp_model.descriptor.serialize()
768-
if isinstance(serialized, dict) and "type" in serialized:
769-
return serialized["type"].upper() + " (with ZBL)"
758+
if descriptor is not None:
759+
serialized = descriptor.serialize()
760+
if isinstance(serialized, dict) and "type" in serialized:
761+
return serialized["type"].upper()
762+
# ZBL and other models: use serialize() API
763+
if hasattr(model, "serialize"):
764+
serialized = model.serialize()
765+
if isinstance(serialized, dict):
766+
model_type = serialized.get("type", "")
767+
if model_type == "zbl":
768+
# ZBL model: get descriptor type from the DP sub-model
769+
models_data = serialized.get("models", [])
770+
if models_data:
771+
descriptor_data = models_data[0].get("descriptor", {})
772+
if isinstance(descriptor_data, dict):
773+
desc_type = descriptor_data.get("type", "UNKNOWN")
774+
return f"{desc_type.upper()} (with ZBL)"
775+
return "UNKNOWN (with ZBL)"
770776
return "UNKNOWN"
771777

772-
def count_parameters(model: Any) -> int:
778+
def count_parameters(model: torch.nn.Module) -> int:
773779
"""Count the total number of trainable parameters."""
774780
return sum(p.numel() for p in model.parameters() if p.requires_grad)
775781

source/tests/pt/test_model_summary.py

Lines changed: 159 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -4,140 +4,184 @@
44
import unittest
55
from unittest.mock import (
66
MagicMock,
7+
patch,
78
)
89

910
import torch
1011

1112

12-
class TestGetDescriptorType(unittest.TestCase):
13-
"""Test get_descriptor_type helper function."""
14-
15-
@staticmethod
16-
def get_descriptor_type(model):
17-
"""Replicate the logic from training.py for testing."""
18-
# Standard models have get_descriptor method
19-
if hasattr(model, "get_descriptor"):
20-
descriptor = model.get_descriptor()
21-
serialized = descriptor.serialize()
22-
if isinstance(serialized, dict) and "type" in serialized:
23-
return serialized["type"].upper()
24-
# ZBL models: descriptor is in atomic_model.models[0]
25-
if hasattr(model, "atomic_model") and hasattr(model.atomic_model, "models"):
26-
models = model.atomic_model.models
27-
if models: # Check non-empty
28-
dp_model = models[0]
29-
if hasattr(dp_model, "descriptor"):
30-
serialized = dp_model.descriptor.serialize()
31-
if isinstance(serialized, dict) and "type" in serialized:
32-
return serialized["type"].upper() + " (with ZBL)"
33-
return "UNKNOWN"
34-
35-
def test_standard_model(self):
36-
"""Test descriptor type detection for standard models."""
37-
mock_descriptor = MagicMock()
38-
mock_descriptor.serialize.return_value = {"type": "se_e2_a"}
39-
40-
mock_model = MagicMock()
41-
mock_model.get_descriptor.return_value = mock_descriptor
13+
class TestLogModelSummary(unittest.TestCase):
14+
"""Test _log_model_summary method behavior."""
4215

43-
result = self.get_descriptor_type(mock_model)
44-
self.assertEqual(result, "SE_E2_A")
16+
def _create_mock_trainer(self, multi_task: bool = False):
17+
"""Create a mock Trainer instance for testing."""
18+
trainer = MagicMock()
19+
trainer.multi_task = multi_task
20+
trainer.rank = 0
21+
return trainer
4522

46-
def test_zbl_model(self):
47-
"""Test descriptor type detection for ZBL models."""
23+
def _create_mock_model_with_descriptor(self, desc_type: str):
24+
"""Create a mock model with get_descriptor method."""
4825
mock_descriptor = MagicMock()
49-
mock_descriptor.serialize.return_value = {"type": "dpa1"}
50-
51-
mock_dp_model = MagicMock()
52-
mock_dp_model.descriptor = mock_descriptor
53-
54-
mock_atomic_model = MagicMock()
55-
mock_atomic_model.models = [mock_dp_model]
26+
mock_descriptor.serialize.return_value = {"type": desc_type}
5627

57-
mock_model = MagicMock(spec=[]) # No get_descriptor
58-
mock_model.atomic_model = mock_atomic_model
59-
60-
result = self.get_descriptor_type(mock_model)
61-
self.assertEqual(result, "DPA1 (with ZBL)")
62-
63-
def test_empty_models_list(self):
64-
"""Test handling of empty models list in ZBL model."""
65-
mock_atomic_model = MagicMock()
66-
mock_atomic_model.models = []
28+
mock_model = MagicMock(spec=torch.nn.Module)
29+
mock_model.get_descriptor.return_value = mock_descriptor
30+
mock_model.parameters.return_value = iter(
31+
[torch.nn.Parameter(torch.randn(10, 5))]
32+
)
33+
return mock_model
34+
35+
def _create_mock_zbl_model(self, desc_type: str):
36+
"""Create a mock ZBL model using serialize() API."""
37+
mock_model = MagicMock(spec=torch.nn.Module)
38+
# Remove get_descriptor to simulate ZBL model
39+
del mock_model.get_descriptor
40+
mock_model.serialize.return_value = {
41+
"type": "zbl",
42+
"models": [
43+
{"descriptor": {"type": desc_type}},
44+
{"type": "pairtab"},
45+
],
46+
}
47+
mock_model.parameters.return_value = iter(
48+
[torch.nn.Parameter(torch.randn(10, 5))]
49+
)
50+
return mock_model
51+
52+
@patch("deepmd.pt.train.training.log")
53+
def test_standard_model_log_output(self, mock_log):
54+
"""Test log output for standard models."""
55+
from deepmd.pt.train.training import (
56+
Trainer,
57+
)
58+
59+
trainer = self._create_mock_trainer(multi_task=False)
60+
trainer.model = self._create_mock_model_with_descriptor("se_e2_a")
61+
62+
# Call the actual method
63+
Trainer._log_model_summary(trainer)
64+
65+
# Verify log.info was called with expected descriptor type
66+
calls = [str(call) for call in mock_log.info.call_args_list]
67+
self.assertTrue(any("SE_E2_A" in call for call in calls))
68+
self.assertTrue(any("Model Params" in call for call in calls))
69+
70+
@patch("deepmd.pt.train.training.log")
71+
def test_zbl_model_log_output(self, mock_log):
72+
"""Test log output for ZBL models."""
73+
from deepmd.pt.train.training import (
74+
Trainer,
75+
)
76+
77+
trainer = self._create_mock_trainer(multi_task=False)
78+
trainer.model = self._create_mock_zbl_model("dpa1")
79+
80+
# Call the actual method
81+
Trainer._log_model_summary(trainer)
82+
83+
# Verify log.info was called with expected descriptor type
84+
calls = [str(call) for call in mock_log.info.call_args_list]
85+
self.assertTrue(any("DPA1 (with ZBL)" in call for call in calls))
86+
87+
@patch("deepmd.pt.train.training.log")
88+
def test_multi_task_log_output(self, mock_log):
89+
"""Test log output for multi-task models."""
90+
from deepmd.pt.train.training import (
91+
Trainer,
92+
)
93+
94+
trainer = self._create_mock_trainer(multi_task=True)
95+
trainer.model_keys = ["task1", "task2"]
96+
trainer.model = {
97+
"task1": self._create_mock_model_with_descriptor("dpa2"),
98+
"task2": self._create_mock_model_with_descriptor("se_atten"),
99+
}
100+
101+
# Call the actual method
102+
Trainer._log_model_summary(trainer)
103+
104+
# Verify log.info was called for each task
105+
calls = [str(call) for call in mock_log.info.call_args_list]
106+
self.assertTrue(any("task1" in call for call in calls))
107+
self.assertTrue(any("task2" in call for call in calls))
108+
self.assertTrue(any("DPA2" in call for call in calls))
109+
self.assertTrue(any("SE_ATTEN" in call for call in calls))
110+
111+
@patch("deepmd.pt.train.training.log")
112+
def test_unknown_model_structure(self, mock_log):
113+
"""Test handling of unknown model structure."""
114+
from deepmd.pt.train.training import (
115+
Trainer,
116+
)
117+
118+
trainer = self._create_mock_trainer(multi_task=False)
119+
# Model without get_descriptor and without serialize returning valid type
120+
mock_model = MagicMock(spec=torch.nn.Module)
121+
del mock_model.get_descriptor
122+
mock_model.serialize.return_value = {"other_key": "value"}
123+
mock_model.parameters.return_value = iter([])
124+
trainer.model = mock_model
125+
126+
# Call the actual method
127+
Trainer._log_model_summary(trainer)
128+
129+
# Verify "UNKNOWN" appears in output
130+
calls = [str(call) for call in mock_log.info.call_args_list]
131+
self.assertTrue(any("UNKNOWN" in call for call in calls))
132+
133+
@patch("deepmd.pt.train.training.log")
134+
def test_none_descriptor(self, mock_log):
135+
"""Test handling when get_descriptor returns None."""
136+
from deepmd.pt.train.training import (
137+
Trainer,
138+
)
139+
140+
trainer = self._create_mock_trainer(multi_task=False)
141+
mock_model = MagicMock(spec=torch.nn.Module)
142+
mock_model.get_descriptor.return_value = None
143+
mock_model.serialize.return_value = {"other_key": "value"}
144+
mock_model.parameters.return_value = iter([])
145+
trainer.model = mock_model
146+
147+
# Call the actual method - should not raise AttributeError
148+
Trainer._log_model_summary(trainer)
149+
150+
# Verify "UNKNOWN" appears in output
151+
calls = [str(call) for call in mock_log.info.call_args_list]
152+
self.assertTrue(any("UNKNOWN" in call for call in calls))
67153

68-
mock_model = MagicMock(spec=[])
69-
mock_model.atomic_model = mock_atomic_model
70154

71-
result = self.get_descriptor_type(mock_model)
72-
self.assertEqual(result, "UNKNOWN")
155+
class TestCountParameters(unittest.TestCase):
156+
"""Test parameter counting behavior through _log_model_summary."""
73157

74-
def test_missing_type_key(self):
75-
"""Test handling of serialize() without 'type' key."""
76-
mock_descriptor = MagicMock()
77-
mock_descriptor.serialize.return_value = {"other_key": "value"}
158+
@patch("deepmd.pt.train.training.log")
159+
def test_parameter_count_in_log(self, mock_log):
160+
"""Test that parameter count is correctly logged."""
161+
from deepmd.pt.train.training import (
162+
Trainer,
163+
)
78164

79-
mock_model = MagicMock()
80-
mock_model.get_descriptor.return_value = mock_descriptor
165+
trainer = MagicMock()
166+
trainer.multi_task = False
167+
trainer.rank = 0
81168

82-
result = self.get_descriptor_type(mock_model)
83-
self.assertEqual(result, "UNKNOWN")
169+
# Create model with known parameter count
170+
real_model = torch.nn.Linear(10, 5).to("cpu") # 10*5 + 5 = 55 parameters
84171

85-
def test_serialize_returns_non_dict(self):
86-
"""Test handling of serialize() returning non-dict."""
172+
# Add mock methods
87173
mock_descriptor = MagicMock()
88-
mock_descriptor.serialize.return_value = "not_a_dict"
89-
90-
mock_model = MagicMock()
91-
mock_model.get_descriptor.return_value = mock_descriptor
92-
93-
result = self.get_descriptor_type(mock_model)
94-
self.assertEqual(result, "UNKNOWN")
174+
mock_descriptor.serialize.return_value = {"type": "test"}
175+
real_model.get_descriptor = MagicMock(return_value=mock_descriptor)
95176

96-
def test_unknown_model_structure(self):
97-
"""Test handling of unknown model structure."""
98-
mock_model = MagicMock(spec=[]) # No get_descriptor, no atomic_model
99-
result = self.get_descriptor_type(mock_model)
100-
self.assertEqual(result, "UNKNOWN")
177+
trainer.model = real_model
101178

179+
# Call the actual method
180+
Trainer._log_model_summary(trainer)
102181

103-
class TestCountParameters(unittest.TestCase):
104-
"""Test count_parameters helper function."""
105-
106-
@staticmethod
107-
def count_parameters(model):
108-
"""Replicate the logic from training.py for testing."""
109-
return sum(p.numel() for p in model.parameters() if p.requires_grad)
110-
111-
def test_all_trainable(self):
112-
"""Test counting when all parameters are trainable."""
113-
with torch.device("cpu"):
114-
model = torch.nn.Linear(10, 5) # 10*5 + 5 = 55 parameters
115-
result = self.count_parameters(model)
116-
self.assertEqual(result, 55)
117-
118-
def test_mixed_trainable(self):
119-
"""Test counting with some frozen parameters."""
120-
with torch.device("cpu"):
121-
model = torch.nn.Sequential(
122-
torch.nn.Linear(10, 5), # 55 params
123-
torch.nn.Linear(5, 3), # 18 params
124-
)
125-
# Freeze first layer
126-
for param in model[0].parameters():
127-
param.requires_grad = False
128-
129-
result = self.count_parameters(model)
130-
self.assertEqual(result, 18) # Only second layer
131-
132-
def test_all_frozen(self):
133-
"""Test counting when all parameters are frozen."""
134-
with torch.device("cpu"):
135-
model = torch.nn.Linear(10, 5)
136-
for param in model.parameters():
137-
param.requires_grad = False
138-
139-
result = self.count_parameters(model)
140-
self.assertEqual(result, 0)
182+
# Verify parameter count is logged (55 params = 0.000055 M)
183+
calls = [str(call) for call in mock_log.info.call_args_list]
184+
self.assertTrue(any("0.000 M" in call for call in calls))
141185

142186

143187
if __name__ == "__main__":

0 commit comments

Comments
 (0)