Skip to content

Commit 2e2f564

Browse files
committed
Create test_lammps_faparam_pt.py
1 parent d76ef7f commit 2e2f564

1 file changed

Lines changed: 134 additions & 0 deletions

File tree

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Test LAMMPS with default_fparam (PyTorch backend)."""
3+
4+
import os
5+
from pathlib import (
6+
Path,
7+
)
8+
9+
import numpy as np
10+
import pytest
11+
from lammps import (
12+
PyLammps,
13+
)
14+
from write_lmp_data import (
15+
write_lmp_data,
16+
)
17+
18+
pth_file = (
19+
Path(__file__).parent.parent.parent
20+
/ "tests"
21+
/ "infer"
22+
/ "fparam_aparam_default.pth"
23+
)
24+
data_file = Path(__file__).parent / "data.lmp"
25+
26+
# expected values from fparam_aparam_default.pth with default_fparam=[0.25852028]
27+
expected_ae = np.array(
28+
[
29+
-1.038271223729637e-01,
30+
-7.285433579124989e-02,
31+
-9.467600492266426e-02,
32+
-1.467050207422953e-01,
33+
-7.660561676973243e-02,
34+
-7.277296000253175e-02,
35+
]
36+
)
37+
expected_e = np.sum(expected_ae)
38+
expected_f = np.array(
39+
[
40+
6.622266941151356e-02,
41+
5.278739714221517e-02,
42+
2.265728009692279e-02,
43+
-2.606048291367521e-02,
44+
-4.538812303131843e-02,
45+
1.058247419681242e-02,
46+
1.679392617013225e-01,
47+
-2.257826240741907e-03,
48+
-4.490146347357200e-02,
49+
-1.148364179422036e-01,
50+
-1.169790528013792e-02,
51+
6.140403441496690e-02,
52+
-8.078778123309406e-02,
53+
-5.838879041789346e-02,
54+
6.773641084621368e-02,
55+
-1.247724902386317e-02,
56+
6.494524782787654e-02,
57+
-1.174787360813438e-01,
58+
]
59+
).reshape(6, 3)
60+
61+
box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0])
62+
coord = np.array(
63+
[
64+
[12.83, 2.56, 2.18],
65+
[12.09, 2.87, 2.74],
66+
[0.25, 3.32, 1.68],
67+
[3.36, 3.00, 1.81],
68+
[3.51, 2.51, 2.60],
69+
[4.27, 3.22, 1.56],
70+
]
71+
)
72+
type_OH = np.array([1, 1, 1, 1, 1, 1])
73+
74+
75+
def setup_module() -> None:
76+
if os.environ.get("ENABLE_PYTORCH", "1") != "1":
77+
pytest.skip(
78+
"Skip test because PyTorch support is not enabled.",
79+
)
80+
write_lmp_data(box, coord, type_OH, data_file)
81+
82+
83+
def teardown_module() -> None:
84+
if data_file.exists():
85+
os.remove(data_file)
86+
87+
88+
def _lammps(data_file, units="metal") -> PyLammps:
89+
lammps = PyLammps()
90+
lammps.units(units)
91+
lammps.boundary("p p p")
92+
lammps.atom_style("atomic")
93+
lammps.neighbor("2.0 bin")
94+
lammps.neigh_modify("every 10 delay 0 check no")
95+
lammps.read_data(data_file.resolve())
96+
lammps.mass("1 16")
97+
lammps.timestep(0.0005)
98+
lammps.fix("1 all nve")
99+
return lammps
100+
101+
102+
@pytest.fixture
103+
def lammps():
104+
lmp = _lammps(data_file=data_file)
105+
yield lmp
106+
lmp.close()
107+
108+
109+
def test_pair_deepmd_default_fparam(lammps) -> None:
110+
"""Test that model with default_fparam works without providing fparam."""
111+
lammps.pair_style(f"deepmd {pth_file.resolve()} aparam 0.25852028")
112+
lammps.pair_coeff("* *")
113+
lammps.run(0)
114+
assert lammps.eval("pe") == pytest.approx(expected_e)
115+
for ii in range(6):
116+
assert lammps.atoms[ii].force == pytest.approx(
117+
expected_f[lammps.atoms[ii].id - 1]
118+
)
119+
lammps.run(1)
120+
121+
122+
def test_pair_deepmd_default_fparam_explicit(lammps) -> None:
123+
"""Test that explicit fparam still works with default_fparam model."""
124+
lammps.pair_style(
125+
f"deepmd {pth_file.resolve()} fparam 0.25852028 aparam 0.25852028"
126+
)
127+
lammps.pair_coeff("* *")
128+
lammps.run(0)
129+
assert lammps.eval("pe") == pytest.approx(expected_e)
130+
for ii in range(6):
131+
assert lammps.atoms[ii].force == pytest.approx(
132+
expected_f[lammps.atoms[ii].id - 1]
133+
)
134+
lammps.run(1)

0 commit comments

Comments
 (0)