Skip to content

Commit aa386a6

Browse files
Support dpmodel and jax
1 parent 6a66041 commit aa386a6

7 files changed

Lines changed: 297 additions & 69 deletions

File tree

deepmd/dpmodel/fitting/property_fitting.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66

77
import numpy as np
88

9+
from deepmd.dpmodel.output_def import (
10+
FittingOutputDef,
11+
OutputVariableDef,
12+
)
913
from deepmd.dpmodel.common import (
1014
DEFAULT_PRECISION,
1115
)
@@ -108,6 +112,20 @@ def __init__(
108112
type_map=type_map,
109113
)
110114

115+
def output_def(self) -> FittingOutputDef:
116+
return FittingOutputDef(
117+
[
118+
OutputVariableDef(
119+
self.var_name,
120+
[self.dim_out],
121+
reducible=True,
122+
r_differentiable=False,
123+
c_differentiable=False,
124+
intensive=self.intensive,
125+
),
126+
]
127+
)
128+
111129
@classmethod
112130
def deserialize(cls, data: dict) -> "PropertyFittingNet":
113131
data = data.copy()

deepmd/dpmodel/model/make_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def forward_common_atomic(
355355
self.atomic_output_def(),
356356
extended_coord,
357357
do_atomic_virial=do_atomic_virial,
358+
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
358359
)
359360

360361
forward_lower = call_lower

deepmd/dpmodel/model/transform_output.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
get_hessian_name,
1818
get_reduce_name,
1919
)
20+
from typing import Optional
2021

2122

2223
def fit_output_to_model_output(
2324
fit_ret: dict[str, np.ndarray],
2425
fit_output_def: FittingOutputDef,
2526
coord_ext: np.ndarray,
2627
do_atomic_virial: bool = False,
28+
mask: Optional[np.ndarray] = None
2729
) -> dict[str, np.ndarray]:
2830
"""Transform the output of the fitting network to
2931
the model output.
@@ -38,9 +40,21 @@ def fit_output_to_model_output(
3840
if vdef.reducible:
3941
kk_redu = get_reduce_name(kk)
4042
# cast to energy prec before reduction
41-
model_ret[kk_redu] = xp.sum(
42-
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
43-
)
43+
if vdef.intensive:
44+
if (mask is not None) and (mask == 0.0).any():
45+
mask = mask.astype(bool)
46+
model_ret[kk_redu] = xp.stack([
47+
xp.mean(vv[ii].astype(GLOBAL_ENER_FLOAT_PRECISION)[mask[ii]], axis=atom_axis)
48+
for ii in range(mask.shape[0])
49+
])
50+
else:
51+
model_ret[kk_redu] = xp.mean(
52+
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
53+
)
54+
else:
55+
model_ret[kk_redu] = xp.sum(
56+
vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis
57+
)
4458
if vdef.r_differentiable:
4559
kk_derv_r, kk_derv_c = get_deriv_name(kk)
4660
# name-holders

deepmd/jax/model/base_model.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,18 @@ def forward_common_atomic(
4646
atom_axis = -(len(shap) + 1)
4747
if vdef.reducible:
4848
kk_redu = get_reduce_name(kk)
49-
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
49+
if vdef.intensive:
50+
mask = atomic_ret["mask"] if "mask" in atomic_ret else None
51+
if (mask is not None) and (mask==0.0).any():
52+
mask = mask.astype(jnp.bool_)
53+
model_predict[kk_redu] = jnp.stack([
54+
jnp.mean(vv[ii][mask[ii]], axis=atom_axis)
55+
for ii in range(mask.shape[0])
56+
])
57+
else:
58+
model_predict[kk_redu] = jnp.mean(vv, axis=atom_axis)
59+
else:
60+
model_predict[kk_redu] = jnp.sum(vv, axis=atom_axis)
5061
kk_derv_r, kk_derv_c = get_deriv_name(kk)
5162
if vdef.r_differentiable:
5263

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import sys
3+
import unittest
4+
5+
import numpy as np
6+
from copy import deepcopy
7+
8+
from deepmd.dpmodel.descriptor.se_e2_a import (
9+
DescrptSeA,
10+
)
11+
from deepmd.dpmodel.fitting import (
12+
PropertyFittingNet,
13+
)
14+
from deepmd.dpmodel.model.property_model import (
15+
PropertyModel,
16+
)
17+
18+
19+
class TestCaseSingleFrameWithoutNlist:
20+
def setUp(self) -> None:
21+
# nf=2, nloc == 3
22+
self.nloc = 3
23+
self.nt = 2
24+
self.coord = np.array(
25+
[
26+
[
27+
[0, 0, 0],
28+
[0, 1, 0],
29+
[0, 0, 1],
30+
],
31+
[
32+
[1, 0, 1],
33+
[0, 1, 1],
34+
[1, 1, 0],
35+
]
36+
],
37+
dtype=np.float64,
38+
)
39+
self.atype = np.array([[0, 0, 1],[1, 1, 0]], dtype=int).reshape([2, self.nloc])
40+
self.cell = 2.0 * np.eye(3).reshape([1, 9])
41+
self.cell = np.array([self.cell,self.cell]).reshape(2, 9)
42+
self.sel = [16, 8]
43+
self.rcut = 2.2
44+
self.rcut_smth = 0.4
45+
self.atol = 1e-12
46+
47+
48+
class TestPaddingAtoms(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
49+
def setUp(self):
50+
TestCaseSingleFrameWithoutNlist.setUp(self)
51+
52+
def test_padding_atoms_consistency(self):
53+
ds = DescrptSeA(
54+
self.rcut,
55+
self.rcut_smth,
56+
self.sel,
57+
)
58+
ft = PropertyFittingNet(
59+
self.nt,
60+
ds.get_dim_out(),
61+
mixed_types=ds.mixed_types(),
62+
intensive=True,
63+
)
64+
type_map = ["foo", "bar"]
65+
model = PropertyModel(ds, ft, type_map=type_map)
66+
var_name = model.get_var_name()
67+
args = [self.coord, self.atype, self.cell]
68+
result = model.call(*args)
69+
# test intensive
70+
np.testing.assert_allclose(
71+
result[f"{var_name}_redu"],
72+
np.mean(result[f"{var_name}"],axis=1),
73+
atol=self.atol,
74+
)
75+
# test padding atoms
76+
padding_atoms_list = [1, 5, 10]
77+
for padding_atoms in padding_atoms_list:
78+
coord = deepcopy(self.coord)
79+
atype = deepcopy(self.atype)
80+
atype_padding = np.pad(
81+
atype,
82+
pad_width=((0, 0), (0, padding_atoms)),
83+
mode='constant',
84+
constant_values=-1
85+
)
86+
coord_padding = np.pad(
87+
coord,
88+
pad_width=((0, 0), (0, padding_atoms), (0, 0)),
89+
mode='constant',
90+
constant_values=0
91+
)
92+
args = [coord_padding, atype_padding, self.cell]
93+
result_padding = model.call(*args)
94+
np.testing.assert_allclose(
95+
result[f"{var_name}_redu"],
96+
result_padding[f"{var_name}_redu"],
97+
atol=self.atol,
98+
)
99+
100+
101+
if __name__ == "__main__":
102+
unittest.main()

source/tests/jax/test_padding_atoms.py

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,36 +29,43 @@
2929
dtype = jnp.float64
3030

3131

32-
#@unittest.skipIf(
33-
# sys.version_info < (3, 10),
34-
# "JAX requires Python 3.10 or later",
35-
#)
32+
@unittest.skipIf(
33+
sys.version_info < (3, 10),
34+
"JAX requires Python 3.10 or later",
35+
)
3636
class TestCaseSingleFrameWithoutNlist:
3737
def setUp(self) -> None:
38-
# nloc == 3, nall == 4
38+
# nf=2, nloc == 3
3939
self.nloc = 3
40-
self.nf, self.nt = 1, 2
40+
self.nt = 2
4141
self.coord = np.array(
4242
[
43-
[0, 0, 0],
44-
[0, 1, 0],
45-
[0, 0, 1],
43+
[
44+
[0, 0, 0],
45+
[0, 1, 0],
46+
[0, 0, 1],
47+
],
48+
[
49+
[1, 0, 1],
50+
[0, 1, 1],
51+
[1, 1, 0],
52+
]
4653
],
4754
dtype=np.float64,
48-
).reshape([1, self.nloc * 3])
49-
self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
55+
)
56+
self.atype = np.array([[0, 0, 1],[1, 1, 0]], dtype=int).reshape([2, self.nloc])
5057
self.cell = 2.0 * np.eye(3).reshape([1, 9])
51-
# sel = [5, 2]
58+
self.cell = np.array([self.cell,self.cell]).reshape(2, 9)
5259
self.sel = [16, 8]
5360
self.rcut = 2.2
5461
self.rcut_smth = 0.4
5562
self.atol = 1e-12
5663

5764

58-
#@unittest.skipIf(
59-
# sys.version_info < (3, 10),
60-
# "JAX requires Python 3.10 or later",
61-
#)
65+
@unittest.skipIf(
66+
sys.version_info < (3, 10),
67+
"JAX requires Python 3.10 or later",
68+
)
6269
class TestPaddingAtoms(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
6370
def setUp(self):
6471
TestCaseSingleFrameWithoutNlist.setUp(self)
@@ -77,12 +84,40 @@ def test_padding_atoms_consistency(self):
7784
)
7885
type_map = ["foo", "bar"]
7986
model = PropertyModel(ds, ft, type_map=type_map)
87+
var_name = model.get_var_name()
8088
args = [to_jax_array(ii) for ii in [self.coord, self.atype, self.cell]]
81-
ret_base = model.call(*args)
89+
result = model.call(*args)
90+
# test intensive
91+
np.testing.assert_allclose(
92+
to_numpy_array(result[f"{var_name}_redu"]),
93+
np.mean(to_numpy_array(result[f"{var_name}"]),axis=1),
94+
atol=self.atol,
95+
)
96+
# test padding atoms
97+
padding_atoms_list = [1, 5, 10]
98+
for padding_atoms in padding_atoms_list:
99+
coord = deepcopy(self.coord)
100+
atype = deepcopy(self.atype)
101+
atype_padding = np.pad(
102+
atype,
103+
pad_width=((0, 0), (0, padding_atoms)),
104+
mode='constant',
105+
constant_values=-1
106+
)
107+
coord_padding = np.pad(
108+
coord,
109+
pad_width=((0, 0), (0, padding_atoms), (0, 0)),
110+
mode='constant',
111+
constant_values=0
112+
)
113+
args = [to_jax_array(ii) for ii in [coord_padding, atype_padding, self.cell]]
114+
result_padding = model.call(*args)
115+
np.testing.assert_allclose(
116+
to_numpy_array(result[f"{var_name}_redu"]),
117+
to_numpy_array(result_padding[f"{var_name}_redu"]),
118+
atol=self.atol,
119+
)
82120

83121

84-
#np.testing.assert_allclose(
85-
# to_numpy_array(ret0[model.get_var_name()]),
86-
# to_numpy_array(ret1[md1.get_var_name()]),
87-
# atol=self.atol,
88-
#)
122+
if __name__ == "__main__":
123+
unittest.main()

0 commit comments

Comments
 (0)