Skip to content

Commit 58c4bb6

Browse files
committed
feat(lammps): generate masses for lammps/lmp export
Prefer explicitly stored masses when available, and otherwise infer per-type masses from atom_names when all names are valid element symbols. Keep the previous behavior for unknown type names so exports do not emit unsafe Masses sections. Add regression tests covering both known-element and unknown-type cases, and make the new tests independent of the current working directory. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4)
1 parent 1e79653 commit 58c4bb6

2 files changed

Lines changed: 69 additions & 0 deletions

File tree

dpdata/lammps/lmp.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import numpy as np
55

6+
from dpdata.periodic_table import Element, ELEMENTS
7+
68
ptr_float_fmt = "%15.10f"
79
ptr_int_fmt = "%6d"
810
ptr_key_fmt = "%15s"
@@ -484,6 +486,34 @@ def rotate_to_lower_triangle(
484486
return cell, coord
485487

486488

489+
def _get_lammps_masses(system) -> np.ndarray | None:
490+
"""Get masses for the LAMMPS ``Masses`` section.
491+
492+
Prefer explicitly stored masses when available. Otherwise, infer masses from
493+
``atom_names`` when all names are valid chemical element symbols.
494+
495+
Parameters
496+
----------
497+
system : dict
498+
System data dictionary
499+
500+
Returns
501+
-------
502+
np.ndarray or None
503+
Per-type masses aligned with ``atom_names``. Returns ``None`` when the
504+
masses cannot be determined safely.
505+
"""
506+
atom_names = system["atom_names"]
507+
masses = system.get("masses")
508+
if masses is not None and len(masses) >= len(atom_names):
509+
return np.asarray(masses[: len(atom_names)], dtype=float)
510+
511+
if not all(name in ELEMENTS for name in atom_names):
512+
return None
513+
514+
return np.array([Element(name).mass for name in atom_names], dtype=float)
515+
516+
487517
def from_system_data(system, f_idx=0):
488518
ret = ""
489519
ret += "\n"
@@ -514,6 +544,16 @@ def from_system_data(system, f_idx=0):
514544
cell[2][1],
515545
) # noqa: UP031
516546
ret += "\n"
547+
548+
masses = _get_lammps_masses(system)
549+
if masses is not None:
550+
ret += "Masses\n"
551+
ret += "\n"
552+
mass_fmt = ptr_int_fmt + " " + ptr_float_fmt + " # %s\n" # noqa: UP031
553+
for ii, (mass, atom_name) in enumerate(zip(masses, system["atom_names"])):
554+
ret += mass_fmt % (ii + 1, mass, atom_name)
555+
ret += "\n"
556+
517557
ret += "Atoms # atomic\n"
518558
ret += "\n"
519559
coord_fmt = (

tests/test_lammps_lmp_dump.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import os
4+
import tempfile
45
import unittest
56

67
import numpy as np
@@ -9,6 +10,9 @@
910

1011
from dpdata.lammps.lmp import rotate_to_lower_triangle
1112

13+
TEST_DIR = os.path.dirname(__file__)
14+
POSCAR_CONF_LMP = os.path.join(TEST_DIR, "poscars", "conf.lmp")
15+
1216

1317
class TestLmpDump(unittest.TestCase, TestPOSCARoh):
1418
def setUp(self):
@@ -100,5 +104,30 @@ def test_negative_diagonal(self):
100104
)
101105

102106

107+
class TestLmpDumpMasses(unittest.TestCase):
108+
def test_dump_known_elements_writes_masses(self):
109+
system = dpdata.System(POSCAR_CONF_LMP, type_map=["O", "H"])
110+
with tempfile.TemporaryDirectory() as tmpdir:
111+
output = os.path.join(tmpdir, "tmp_masses.lmp")
112+
system.to_lammps_lmp(output)
113+
with open(output) as f:
114+
content = f.read()
115+
116+
self.assertIn("Masses\n", content)
117+
self.assertIn(" 1 15.9994000000 # O", content)
118+
self.assertIn(" 2 1.0079400000 # H", content)
119+
self.assertLess(content.index("Masses\n"), content.index("Atoms # atomic\n"))
120+
121+
def test_dump_unknown_types_skips_masses(self):
122+
system = dpdata.System(POSCAR_CONF_LMP)
123+
with tempfile.TemporaryDirectory() as tmpdir:
124+
output = os.path.join(tmpdir, "tmp_unknown_types.lmp")
125+
system.to_lammps_lmp(output)
126+
with open(output) as f:
127+
content = f.read()
128+
129+
self.assertNotIn("Masses\n", content)
130+
131+
103132
if __name__ == "__main__":
104133
unittest.main()

0 commit comments

Comments
 (0)