Skip to content

Commit ee00f96

Browse files
Chengqian-Zhangpre-commit-ci[bot]
authored and committed
fix: get correct intensive property prediction when using virtual atoms (deepmodeling#4869)
When using virtual atoms, the property output of virtual atom is `0`. - If predicting energy or other extensive properties, it works well, that's because the virtual atom property `0` do not contribute to the total energy or other extensive properties. - However, if predicting intensive properties, there is some error. For example, a frame has two real atoms and two virtual atoms, the atomic property contribution is [2, 2, 0, 0](the atomic property of virtual atoms are always 0), the final property should be `(2+2)/real_atoms = 2`, not be `(2+2)/total_atoms =1`. This PR is used to solve this bug mentioned above. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Models now provide accessors to retrieve property names and their fitting network; property fitting nets expose output definitions. * **Bug Fixes** * Intensive property reduction respects atom masks so padded/dummy atoms are ignored, keeping results invariant to padding. * **Tests** * Added PyTorch, JAX, and core tests validating consistent behavior with padded atoms. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 800a3dc commit ee00f96

11 files changed

Lines changed: 405 additions & 5 deletions

File tree

deepmd/dpmodel/fitting/property_fitting.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from deepmd.dpmodel.fitting.invar_fitting import (
1313
InvarFitting,
1414
)
15+
from deepmd.dpmodel.output_def import (
16+
FittingOutputDef,
17+
OutputVariableDef,
18+
)
1519
from deepmd.utils.version import (
1620
check_version_compatibility,
1721
)
@@ -108,6 +112,20 @@ def __init__(
108112
type_map=type_map,
109113
)
110114

115+
def output_def(self) -> FittingOutputDef:
    """Return the fitting output definition for the predicted property.

    The single output variable is reducible (it may be summed/averaged over
    atoms) and carries the net's ``intensive`` flag, but is declared
    non-differentiable w.r.t. both coordinates and cell.
    """
    property_def = OutputVariableDef(
        self.var_name,
        [self.dim_out],
        reducible=True,
        r_differentiable=False,
        c_differentiable=False,
        intensive=self.intensive,
    )
    return FittingOutputDef([property_def])
128+
111129
@classmethod
112130
def deserialize(cls, data: dict) -> "PropertyFittingNet":
113131
data = data.copy()

deepmd/dpmodel/model/dp_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ def update_sel(
4545
train_data, type_map, local_jdata["descriptor"]
4646
)
4747
return local_jdata_cpy, min_nbor_dist
48+
49+
def get_fitting_net(self):
    """Return the fitting network held by the wrapped atomic model."""
    fitting_net = self.atomic_model.fitting
    return fitting_net

deepmd/dpmodel/model/make_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def forward_common_atomic(
355355
self.atomic_output_def(),
356356
extended_coord,
357357
do_atomic_virial=do_atomic_virial,
358+
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
358359
)
359360

360361
forward_lower = call_lower

deepmd/dpmodel/model/property_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ def __init__(
2525
) -> None:
2626
DPModelCommon.__init__(self)
2727
DPPropertyModel_.__init__(self, *args, **kwargs)
28+
29+
def get_var_name(self) -> str:
    """Return the name of the property predicted by this model."""
    fitting = self.get_fitting_net()
    return fitting.var_name

deepmd/dpmodel/model/transform_output.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3+
from typing import (
4+
Optional,
5+
)
6+
37
import array_api_compat
48
import numpy as np
59

@@ -24,6 +28,7 @@ def fit_output_to_model_output(
2428
fit_output_def: FittingOutputDef,
2529
coord_ext: np.ndarray,
2630
do_atomic_virial: bool = False,
31+
mask: Optional[np.ndarray] = None,
2732
) -> dict[str, np.ndarray]:
2833
"""Transform the output of the fitting network to
2934
the model output.
@@ -38,9 +43,19 @@ def fit_output_to_model_output(
3843
if vdef.reducible:
3944
kk_redu = get_reduce_name(kk)
4045
# cast to energy prec before reduction
41-
model_ret[kk_redu] = xp.sum(
42-
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
43-
)
46+
if vdef.intensive:
47+
if mask is not None:
48+
model_ret[kk_redu] = xp.sum(
49+
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
50+
) / np.sum(mask, axis=-1, keepdims=True)
51+
else:
52+
model_ret[kk_redu] = xp.mean(
53+
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
54+
)
55+
else:
56+
model_ret[kk_redu] = xp.sum(
57+
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
58+
)
4459
if vdef.r_differentiable:
4560
kk_derv_r, kk_derv_c = get_deriv_name(kk)
4661
# name-holders

deepmd/jax/model/base_model.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,16 @@ def forward_common_atomic(
4646
atom_axis = -(len(shap) + 1)
4747
if vdef.reducible:
4848
kk_redu = get_reduce_name(kk)
49-
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
49+
if vdef.intensive:
50+
mask = atomic_ret["mask"] if "mask" in atomic_ret else None
51+
if mask is not None:
52+
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis) / jnp.sum(
53+
mask, axis=-1, keepdims=True
54+
)
55+
else:
56+
model_predict[kk_redu] = jnp.mean(vv, axis=atom_axis)
57+
else:
58+
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
5059
kk_derv_r, kk_derv_c = get_deriv_name(kk)
5160
if vdef.r_differentiable:
5261

deepmd/pt/model/model/make_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def forward_common_lower(
299299
cc_ext,
300300
do_atomic_virial=do_atomic_virial,
301301
create_graph=self.training,
302+
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
302303
)
303304
model_predict = self.output_type_cast(model_predict, input_prec)
304305
return model_predict

deepmd/pt/model/model/transform_output.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def fit_output_to_model_output(
158158
coord_ext: torch.Tensor,
159159
do_atomic_virial: bool = False,
160160
create_graph: bool = True,
161+
mask: Optional[torch.Tensor] = None,
161162
) -> dict[str, torch.Tensor]:
162163
"""Transform the output of the fitting network to
163164
the model output.
@@ -172,7 +173,12 @@ def fit_output_to_model_output(
172173
if vdef.reducible:
173174
kk_redu = get_reduce_name(kk)
174175
if vdef.intensive:
175-
model_ret[kk_redu] = torch.mean(vv.to(redu_prec), dim=atom_axis)
176+
if mask is not None:
177+
model_ret[kk_redu] = torch.sum(
178+
vv.to(redu_prec), dim=atom_axis
179+
) / torch.sum(mask, dim=-1, keepdim=True)
180+
else:
181+
model_ret[kk_redu] = torch.mean(vv.to(redu_prec), dim=atom_axis)
176182
else:
177183
model_ret[kk_redu] = torch.sum(vv.to(redu_prec), dim=atom_axis)
178184
if vdef.r_differentiable:
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import unittest
3+
from copy import (
4+
deepcopy,
5+
)
6+
7+
import numpy as np
8+
9+
from deepmd.dpmodel.descriptor.se_e2_a import (
10+
DescrptSeA,
11+
)
12+
from deepmd.dpmodel.fitting import (
13+
PropertyFittingNet,
14+
)
15+
from deepmd.dpmodel.model.property_model import (
16+
PropertyModel,
17+
)
18+
19+
20+
class TestCaseSingleFrameWithoutNlist:
    """Shared fixture: two frames of three atoms each, no neighbor list."""

    def setUp(self) -> None:
        # nf = 2 frames, nloc = 3 local atoms, nt = 2 atom types
        self.nloc = 3
        self.nt = 2
        frame0 = [[0, 0, 0], [0, 1, 0], [0, 0, 1]]
        frame1 = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
        self.coord = np.array([frame0, frame1], dtype=np.float64)
        self.atype = np.array([[0, 0, 1], [1, 1, 0]], dtype=int).reshape(
            [2, self.nloc]
        )
        # identical cubic cells (edge 2.0) for both frames, flattened to 9
        single_cell = 2.0 * np.eye(3).reshape([1, 9])
        self.cell = np.array([single_cell, single_cell]).reshape(2, 9)
        self.sel = [16, 8]
        self.rcut = 2.2
        self.rcut_smth = 0.4
        self.atol = 1e-12
47+
48+
49+
class TestPaddingAtoms(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
    """Verify intensive property predictions are invariant to padded atoms.

    Padded (virtual) atoms are appended with atype -1; their atomic
    contribution is expected to be zero, so the reduced intensive property
    must match the unpadded result.
    """

    def setUp(self):
        TestCaseSingleFrameWithoutNlist.setUp(self)

    def test_padding_atoms_consistency(self):
        # Build a property model with an intensive fitting net on a se_e2_a
        # descriptor; the fixture supplies coords/types/cell for two frames.
        ds = DescrptSeA(
            self.rcut,
            self.rcut_smth,
            self.sel,
        )
        ft = PropertyFittingNet(
            self.nt,
            ds.get_dim_out(),
            mixed_types=ds.mixed_types(),
            intensive=True,
        )
        type_map = ["foo", "bar"]
        model = PropertyModel(ds, ft, type_map=type_map)
        var_name = model.get_var_name()
        args = [self.coord, self.atype, self.cell]
        result = model.call(*args)
        # test intensive: the reduced property must equal the per-atom mean
        np.testing.assert_allclose(
            result[f"{var_name}_redu"],
            np.mean(result[f"{var_name}"], axis=1),
            atol=self.atol,
        )
        # test padding atoms: appending virtual atoms (atype -1, coords 0)
        # must leave the reduced intensive property unchanged
        padding_atoms_list = [1, 5, 10]
        for padding_atoms in padding_atoms_list:
            coord = deepcopy(self.coord)
            atype = deepcopy(self.atype)
            atype_padding = np.pad(
                atype,
                pad_width=((0, 0), (0, padding_atoms)),
                mode="constant",
                constant_values=-1,
            )
            coord_padding = np.pad(
                coord,
                pad_width=((0, 0), (0, padding_atoms), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            args = [coord_padding, atype_padding, self.cell]
            result_padding = model.call(*args)
            np.testing.assert_allclose(
                result[f"{var_name}_redu"],
                result_padding[f"{var_name}_redu"],
                atol=self.atol,
            )
100+
101+
102+
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import sys
3+
import unittest
4+
from copy import (
5+
deepcopy,
6+
)
7+
8+
import numpy as np
9+
10+
from deepmd.dpmodel.common import (
11+
to_numpy_array,
12+
)
13+
14+
if sys.version_info >= (3, 10):
15+
from deepmd.jax.common import (
16+
to_jax_array,
17+
)
18+
from deepmd.jax.descriptor.se_e2_a import (
19+
DescrptSeA,
20+
)
21+
from deepmd.jax.env import (
22+
jnp,
23+
)
24+
from deepmd.jax.fitting.fitting import (
25+
PropertyFittingNet,
26+
)
27+
from deepmd.jax.model.property_model import (
28+
PropertyModel,
29+
)
30+
31+
dtype = jnp.float64
32+
33+
34+
@unittest.skipIf(
    sys.version_info < (3, 10),
    "JAX requires Python 3.10 or later",
)
class TestCaseSingleFrameWithoutNlist:
    """Shared fixture: two frames of three atoms each, no neighbor list."""

    def setUp(self) -> None:
        # nf = 2 frames, nloc = 3 local atoms, nt = 2 atom types
        self.nloc = 3
        self.nt = 2
        frame0 = [[0, 0, 0], [0, 1, 0], [0, 0, 1]]
        frame1 = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
        self.coord = np.array([frame0, frame1], dtype=np.float64)
        self.atype = np.array([[0, 0, 1], [1, 1, 0]], dtype=int).reshape(
            [2, self.nloc]
        )
        # identical cubic cells (edge 2.0) for both frames, flattened to 9
        single_cell = 2.0 * np.eye(3).reshape([1, 9])
        self.cell = np.array([single_cell, single_cell]).reshape(2, 9)
        self.sel = [16, 8]
        self.rcut = 2.2
        self.rcut_smth = 0.4
        self.atol = 1e-12
65+
66+
67+
@unittest.skipIf(
    sys.version_info < (3, 10),
    "JAX requires Python 3.10 or later",
)
class TestPaddingAtoms(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
    """Verify intensive property predictions ignore padded atoms (JAX backend).

    Padded (virtual) atoms are appended with atype -1; the reduced intensive
    property must match the unpadded result.
    """

    def setUp(self):
        TestCaseSingleFrameWithoutNlist.setUp(self)

    def test_padding_atoms_consistency(self):
        # Build a property model with an intensive fitting net on a se_e2_a
        # descriptor; inputs are converted to JAX arrays before the call.
        ds = DescrptSeA(
            self.rcut,
            self.rcut_smth,
            self.sel,
        )
        ft = PropertyFittingNet(
            self.nt,
            ds.get_dim_out(),
            mixed_types=ds.mixed_types(),
            intensive=True,
        )
        type_map = ["foo", "bar"]
        model = PropertyModel(ds, ft, type_map=type_map)
        var_name = model.get_var_name()
        args = [to_jax_array(ii) for ii in [self.coord, self.atype, self.cell]]
        result = model.call(*args)
        # test intensive: the reduced property must equal the per-atom mean
        np.testing.assert_allclose(
            to_numpy_array(result[f"{var_name}_redu"]),
            np.mean(to_numpy_array(result[f"{var_name}"]), axis=1),
            atol=self.atol,
        )
        # test padding atoms: appending virtual atoms (atype -1, coords 0)
        # must leave the reduced intensive property unchanged
        padding_atoms_list = [1, 5, 10]
        for padding_atoms in padding_atoms_list:
            coord = deepcopy(self.coord)
            atype = deepcopy(self.atype)
            atype_padding = np.pad(
                atype,
                pad_width=((0, 0), (0, padding_atoms)),
                mode="constant",
                constant_values=-1,
            )
            coord_padding = np.pad(
                coord,
                pad_width=((0, 0), (0, padding_atoms), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            args = [
                to_jax_array(ii) for ii in [coord_padding, atype_padding, self.cell]
            ]
            result_padding = model.call(*args)
            np.testing.assert_allclose(
                to_numpy_array(result[f"{var_name}_redu"]),
                to_numpy_array(result_padding[f"{var_name}_redu"]),
                atol=self.atol,
            )
124+
125+
126+
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)