Skip to content

Commit c834131

Browse files
Add unittest
1 parent 546e662 commit c834131

5 files changed

Lines changed: 54 additions & 7 deletions

File tree

deepmd/dpmodel/fitting/property_fitting.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,13 @@ def __init__(
8888
exclude_types: list[int] = [],
8989
type_map: list[str] | None = None,
9090
default_fparam: list | None = None,
91+
distinguish_types: bool = True,
9192
# not used
9293
seed: int | None = None,
9394
) -> None:
9495
self.task_dim = task_dim
9596
self.intensive = intensive
97+
self.distinguish_types = distinguish_types
9698
super().__init__(
9799
var_name=property_name,
98100
ntypes=ntypes,
@@ -131,7 +133,7 @@ def output_def(self) -> FittingOutputDef:
131133
@classmethod
132134
def deserialize(cls, data: dict) -> "PropertyFittingNet":
133135
data = data.copy()
134-
check_version_compatibility(data.pop("@version"), 5, 1)
136+
check_version_compatibility(data.pop("@version"), 6, 1)
135137
data.pop("dim_out")
136138
data["property_name"] = data.pop("var_name")
137139
data.pop("tot_ener_zero")
@@ -150,7 +152,12 @@ def serialize(self) -> dict:
150152
"type": "property",
151153
"task_dim": self.task_dim,
152154
"intensive": self.intensive,
155+
"distinguish_types": self.distinguish_types,
153156
}
154-
dd["@version"] = 5
157+
dd["@version"] = 6
155158

156159
return dd
160+
161+
def get_distinguish_types(self) -> bool:
162+
"""Get whether the fitting net computes stats which are distinguished between different types of atoms."""
163+
return self.distinguish_types

deepmd/pt/model/atomic_model/property_atomic_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626

2727
def get_compute_stats_distinguish_types(self) -> bool:
2828
"""Get whether the fitting net computes stats which are not distinguished between different types of atoms."""
29-
return True
29+
return self.fitting_net.get_distinguish_types()
3030

3131
def get_intensive(self) -> bool:
3232
"""Whether the fitting property is intensive."""

deepmd/pt/model/task/property.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class PropertyFittingNet(InvarFitting):
7070
different fitting nets for different atom types.
7171
seed : int, optional
7272
Random seed.
73+
distinguish_types : bool
74+
Whether to distinguish atom types when computing output statistics.
7375
"""
7476

7577
def __init__(
@@ -91,10 +93,12 @@ def __init__(
9193
trainable: bool | list[bool] = True,
9294
seed: int | None = None,
9395
default_fparam: list | None = None,
96+
distinguish_types: bool = True,
9497
**kwargs: Any,
9598
) -> None:
9699
self.task_dim = task_dim
97100
self.intensive = intensive
101+
self.distinguish_types = distinguish_types
98102
super().__init__(
99103
var_name=property_name,
100104
ntypes=ntypes,
@@ -133,10 +137,14 @@ def get_intensive(self) -> bool:
133137
"""Whether the fitting property is intensive."""
134138
return self.intensive
135139

140+
def get_distinguish_types(self) -> bool:
141+
"""Get whether to distinguish atom types when computing output statistics."""
142+
return self.distinguish_types
143+
136144
@classmethod
137145
def deserialize(cls, data: dict) -> "PropertyFittingNet":
138146
data = data.copy()
139-
check_version_compatibility(data.pop("@version", 1), 5, 1)
147+
check_version_compatibility(data.pop("@version", 1), 6, 1)
140148
data.pop("dim_out")
141149
data["property_name"] = data.pop("var_name")
142150
obj = super().deserialize(data)
@@ -150,8 +158,9 @@ def serialize(self) -> dict:
150158
"type": "property",
151159
"task_dim": self.task_dim,
152160
"intensive": self.intensive,
161+
"distinguish_types": self.distinguish_types,
153162
}
154-
dd["@version"] = 5
163+
dd["@version"] = 6
155164

156165
return dd
157166

deepmd/utils/argcheck.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,6 +1925,7 @@ def fitting_property() -> list[Argument]:
19251925
doc_seed = "Random seed for parameter initialization of the fitting net"
19261926
doc_task_dim = "The dimension of outputs of fitting net"
19271927
doc_intensive = "Whether the fitting property is intensive"
1928+
doc_distinguish_types = "Whether to distinguish atom types when computing output statistics."
19281929
doc_property_name = "The names of fitting property, which should be consistent with the property name in the dataset."
19291930
doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\
19301931
- bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\
@@ -1966,6 +1967,7 @@ def fitting_property() -> list[Argument]:
19661967
Argument("seed", [int, None], optional=True, doc=doc_seed),
19671968
Argument("task_dim", int, optional=True, default=1, doc=doc_task_dim),
19681969
Argument("intensive", bool, optional=True, default=False, doc=doc_intensive),
1970+
Argument("distinguish_types", bool, optional=True, default=True, doc=doc_distinguish_types),
19691971
Argument(
19701972
"property_name",
19711973
str,

source/tests/common/test_out_stat.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_compute_stats_from_redu_with_assigned_bias(self) -> None:
9191
)
9292

9393
def test_compute_stats_do_not_distinguish_types_intensive(self) -> None:
94-
"""Test compute_stats_property function with intensive scenario."""
94+
"""Test compute_stats_do_not_distinguish function with intensive scenario."""
9595
bias, std = compute_stats_do_not_distinguish_types(
9696
self.output_redu, self.natoms, intensive=True
9797
)
@@ -110,7 +110,7 @@ def test_compute_stats_do_not_distinguish_types_intensive(self) -> None:
110110
)
111111

112112
def test_compute_stats_do_not_distinguish_types_extensive(self) -> None:
113-
"""Test compute_stats_property function with extensive scenario."""
113+
"""Test compute_stats_do_not_distinguish function with extensive scenario."""
114114
bias, std = compute_stats_do_not_distinguish_types(
115115
self.output_redu, self.natoms
116116
)
@@ -142,6 +142,35 @@ def test_compute_stats_do_not_distinguish_types_extensive(self) -> None:
142142
rtol=1e-7,
143143
)
144144

145+
def test_compute_stats_from_redu_intensive(self) -> None:
146+
"""Test compute_stats_from_redu function with intensive scenario."""
147+
bias, std = compute_stats_from_redu(
148+
self.output_redu, self.natoms, intensive=True,
149+
)
150+
# Test shapes
151+
assert bias.shape == (len(self.mean), self.output_redu.shape[1])
152+
assert std.shape == (self.output_redu.shape[1],)
153+
154+
# Test values
155+
np.testing.assert_allclose(
156+
bias,
157+
np.array(
158+
[
159+
[8926338.68432182, 8750110.71559034, 2045325.12109175, 1392024.84192495, 6714978.25878314],
160+
[554163.59820041, 5965821.3924394, 2171555.69509784, 8050760.64873761, 5277414.78728998],
161+
[9180265.02004177, 6836013.36530394, 9121797.79540738, 7801570.3259364, 4095707.84597587]
162+
]
163+
),
164+
rtol=1e-6,
165+
)
166+
np.testing.assert_allclose(
167+
std,
168+
np.array(
169+
[0.01700638, 0.01954897, 0.02028186, 0.01074124, 0.02025821]
170+
),
171+
rtol=1e-6,
172+
)
173+
145174
def test_compute_stats_from_atomic(self) -> None:
146175
bias, std = compute_stats_from_atomic(self.output, self.atype)
147176
np.testing.assert_allclose(bias, self.mean)

0 commit comments

Comments
 (0)