Skip to content

Commit 36fa629

Browse files
committed
Rename and update OpenMX test to reflect multiple format support
1 parent c15ac46 commit 36fa629

2 files changed

Lines changed: 60 additions & 64 deletions

File tree

tests/test_openmx_check_convergence.py

Lines changed: 0 additions & 64 deletions
This file was deleted.
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
5+
import numpy as np
6+
from context import dpdata
7+
8+
9+
class TestOPENMXTRAJProps:
10+
def test_atom_names(self):
11+
self.assertEqual(self.system.data["atom_names"], ["Au", "H"])
12+
13+
def test_atom_numbs(self):
14+
self.assertEqual(self.system.data["atom_numbs"], [3, 1])
15+
16+
def test_atom_types(self):
17+
self.assertEqual(list(self.system.data["atom_types"]), [0, 0, 0, 1])
18+
19+
def test_cell(self):
20+
cell = np.array([
21+
[211.86272, 0.0, 0.0],
22+
[0.0, 2.88309, 0.0],
23+
[0.0, -1.44154, 2.49683]
24+
])
25+
cells = np.array([cell, cell])
26+
self.assertEqual(self.system.get_nframes(), 2)
27+
for ff in range(self.system.get_nframes()):
28+
for ii in range(3):
29+
for jj in range(3):
30+
self.assertAlmostEqual(
31+
self.system["cells"][ff][ii][jj], cells[ii][jj]
32+
)
33+
34+
def test_coord(self):
35+
coord = np.array([
36+
[21.18627, 0.0, 1.66455],
37+
[23.5403, 1.44154, 0.83228],
38+
[25.89433, 0.0, 0.0],
39+
[28.24836, 0.0, 1.66455]
40+
])
41+
coords = np.array([coord, coord])
42+
for ff in range(self.system.get_nframes()):
43+
for ii in range(4):
44+
for jj in range(3):
45+
self.assertAlmostEqual(
46+
self.system["coords"][ff][ii][jj], coords[ff][ii][jj]
47+
)
48+
49+
class TestOPENMXTraj(unittest.TestCase, TestOPENMXTRAJProps):
50+
def setUp(self):
51+
self.system = dpdata.System("openmx/Au111Surface", fmt="openmx/md")
52+
53+
54+
class TestOPENMXLabeledTraj(unittest.TestCase, TestOPENMXTRAJProps):
55+
def setUp(self):
56+
self.system = dpdata.LabeledSystem("openmx/Au111Surface", fmt="openmx/md")
57+
58+
59+
if __name__ == "__main__":
60+
unittest.main()

0 commit comments

Comments
 (0)