Skip to content

Commit c15ac46

Browse files
committed
Improve OpenMX cell parsing robustness and SCF convergence handling
1 parent e683cfd commit c15ac46

1 file changed

Lines changed: 22 additions & 42 deletions

File tree

dpdata/openmx/omx.py

Lines changed: 22 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,6 @@
2727
import warnings
2828
from collections import OrderedDict
2929

30-
### iterout.c from OpenMX soure code: column numbers and physical quantities ###
31-
# /* 1: */
32-
# /* 2,3,4: */
33-
# /* 5,6,7: force *
34-
# /* 8: x-component of velocity */
35-
# /* 9: y-component of velocity */
36-
# /* 10: z-component of velocity */
37-
# /* 11: Net charge, electron charge is defined to be negative. */
38-
# /* 12: magnetic moment (muB) */
39-
# /* 13,14: angles of spin */
40-
41-
# 15: scf_convergence_flag (optional)
42-
#
43-
# 1. Move the declaration of `scf_convergence_flag` in `DFT.c` to `openmx_common.h`.
44-
# 2. Add `scf_convergence_flag` output to the end of `iterout.c` where `*.md` is written.
45-
# 3. Recompile OpenMX.
46-
4730

4831
def load_atom(lines):
4932
atom_names = []
@@ -56,9 +39,8 @@ def load_atom(lines):
5639
elif atom_names_mode:
5740
parts = line.split()
5841
atom_names.append(parts[1])
59-
natoms = len(atom_names)
6042
atom_names_original = atom_names
61-
atom_names = list(OrderedDict.fromkeys(set(atom_names))) # Python>=3.7
43+
atom_names = list(OrderedDict.fromkeys(set(atom_names)))
6244
atom_names = sorted(
6345
atom_names, key=atom_names_original.index
6446
) # Unique ordering of atomic species
@@ -82,24 +64,25 @@ def load_atom(lines):
8264

8365

8466
def load_cells(lines):
85-
cell, cells = [], []
86-
for index, line in enumerate(lines):
67+
cells = []
68+
for line in lines:
8769
if "Cell_Vectors=" in line:
88-
parts = line.split()
89-
if len(parts) == 21: # MD.Type is NVT_NH
90-
cell.append([float(parts[12]), float(parts[13]), float(parts[14])])
91-
cell.append([float(parts[15]), float(parts[16]), float(parts[17])])
92-
cell.append([float(parts[18]), float(parts[19]), float(parts[20])])
93-
elif len(parts) == 16: # MD.Type is Opt
94-
cell.append([float(parts[7]), float(parts[8]), float(parts[9])])
95-
cell.append([float(parts[10]), float(parts[11]), float(parts[12])])
96-
cell.append([float(parts[13]), float(parts[14]), float(parts[15])])
97-
else:
98-
raise RuntimeError(
99-
"Does the file System.Name.md contain unsupported calculation results?"
100-
)
70+
part = line.split("Cell_Vectors=")[1]
71+
parts = part.split()
72+
if len(parts) < 9:
73+
raise RuntimeError("Cell_Vectors does not contain enough elements.")
74+
values = list(map(float, parts[:9]))
75+
cell = [values[0:3], values[3:6], values[6:9]]
10176
cells.append(cell)
102-
cell = []
77+
# Checking SCF converged or not
78+
for token in line.split():
79+
if token.startswith("scf_conv="):
80+
try:
81+
scf_conv = int(token.split("=")[1])
82+
if scf_conv == 0:
83+
warnings.warn(f"SCF not converged!")
84+
except (IndexError, ValueError):
85+
pass
10386
cells = np.array(cells)
10487
return cells
10588

@@ -119,7 +102,7 @@ def load_param_file(fname: FileType, mdname: FileType):
119102
def load_coords(lines, atom_names, natoms):
120103
cnt = 0
121104
coord, coords = [], []
122-
for index, line in enumerate(lines):
105+
for line in lines:
123106
if "time=" in line:
124107
continue
125108
for atom_name in atom_names:
@@ -129,9 +112,6 @@ def load_coords(lines, atom_names, natoms):
129112
parts = line.split()
130113
for_line = [float(parts[1]), float(parts[2]), float(parts[3])]
131114
coord.append(for_line)
132-
# It may be necessary to recompile OpenMX to make scf convergence determination.
133-
if len(parts) == 15 and parts[14] == "0":
134-
warnings.warn("SCF in System.Name.md has not converged!")
135115
if cnt == natoms:
136116
coords.append(coord)
137117
cnt = 0
@@ -180,7 +160,7 @@ def load_energy(lines):
180160
def load_force(lines, atom_names, atom_numbs):
181161
cnt = 0
182162
field, fields = [], []
183-
for index, line in enumerate(lines):
163+
for line in lines:
184164
if "time=" in line:
185165
continue
186166
for atom_name in atom_names:
@@ -209,7 +189,7 @@ def to_system_label(fname, mdname):
209189

210190

211191
if __name__ == "__main__":
212-
file_name = "Cdia"
192+
file_name = "Au111Surface"
213193
fname = f"{file_name}.dat"
214194
mdname = f"{file_name}.md"
215195
atom_names, atom_numbs, atom_types, cells = load_param_file(fname, mdname)
@@ -222,4 +202,4 @@ def to_system_label(fname, mdname):
222202
# print(cells.shape)
223203
# print(coords.shape)
224204
# print(len(energy))
225-
# print(force.shape)
205+
# print(force.shape)

0 commit comments

Comments
 (0)