Skip to content

Commit 9e33657

Browse files
committed
fix(lammps): reject mismatched explicit masses
Reject explicit system["masses"] arrays unless they match atom_names exactly to avoid silently truncating and mis-mapping LAMMPS Masses entries. Add a regression test covering the mismatched-length case. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4)
1 parent 6d6dcda commit 9e33657

2 files changed

Lines changed: 24 additions & 2 deletions

File tree

dpdata/lammps/lmp.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,11 +502,24 @@ def _get_lammps_masses(system) -> np.ndarray | None:
502502
np.ndarray or None
503503
Per-type masses aligned with ``atom_names``. Returns ``None`` when the
504504
masses cannot be determined safely.
505+
506+
Raises
507+
------
508+
ValueError
509+
If explicit ``system["masses"]`` is present but does not match the
510+
length of ``atom_names``.
505511
"""
506512
atom_names = system["atom_names"]
507513
masses = system.get("masses")
508-
if masses is not None and len(masses) >= len(atom_names):
509-
return np.asarray(masses[: len(atom_names)], dtype=float)
514+
if masses is not None:
515+
masses = np.asarray(masses, dtype=float)
516+
if masses.ndim != 1 or len(masses) != len(atom_names):
517+
raise ValueError(
518+
'Explicit system["masses"] must be a 1D array with the same '
519+
'length as system["atom_names"] to write the LAMMPS Masses '
520+
"section."
521+
)
522+
return masses
510523

511524
if not all(name in ELEMENTS for name in atom_names):
512525
return None

tests/test_lammps_lmp_dump.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ def test_dump_unknown_types_skips_masses(self):
128128

129129
self.assertNotIn("Masses\n", content)
130130

131+
def test_dump_rejects_mismatched_explicit_masses(self):
132+
system = dpdata.System(POSCAR_CONF_LMP, type_map=["O", "H"])
133+
system.data["masses"] = np.array([15.9994, 1.00794, 99.0])
134+
135+
with tempfile.TemporaryDirectory() as tmpdir:
136+
output = os.path.join(tmpdir, "tmp_bad_masses.lmp")
137+
with self.assertRaisesRegex(ValueError, r'system\["masses"\]'):
138+
system.to_lammps_lmp(output)
139+
131140

132141
if __name__ == "__main__":
133142
unittest.main()

0 commit comments

Comments
 (0)