Skip to content

Commit 38e8763

Browse files
SchrodingersCatttpre-commit-ci[bot]cursoragent
authored
feat: support zero-count elements in type_map for sort_atom_names (#912)
- Refactor `sort_atom_names` to correctly handle `type_map` with zero-count elements while preserving existing alphabetical sorting behavior. - Only validate atom types that actually appear (count > 0) against the provided `type_map`, **allowing new elements in `type_map` to be added with zero count**. - Improve robustness and clarity of atom type remapping logic, restoring original commenting style for maintainability. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Enforced mapping order for atom names; raises clear error if any active atom type is missing from the provided mapping. * Consistently reorders atom names, counts, and type indices to match a provided mapping or an alphabetical fallback, and preserves zero-count entries where appropriate. * **Tests** * Added unit tests covering mapping-based sorting, zero-count handling, missing-active-type errors, and alphabetical sorting. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 99ea3bb commit 38e8763

2 files changed

Lines changed: 139 additions & 21 deletions

File tree

dpdata/utils.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@ def add_atom_names(data, atom_names):
6262

6363

6464
def sort_atom_names(data, type_map=None):
65-
"""Sort atom_names of the system and reorder atom_numbs and atom_types accoarding
65+
"""Sort atom_names of the system and reorder atom_numbs and atom_types according
6666
to atom_names. If type_map is not given, atom_names will be sorted by
67-
alphabetical order. If type_map is given, atom_names will be type_map.
67+
alphabetical order. If type_map is given, atom_names will be set to type_map,
68+
and zero-count elements are kept.
6869
6970
Parameters
7071
----------
@@ -74,28 +75,59 @@ def sort_atom_names(data, type_map=None):
7475
type_map
7576
"""
7677
if type_map is not None:
77-
# assign atom_names index to the specify order
78-
# atom_names must be a subset of type_map
79-
assert set(data["atom_names"]).issubset(set(type_map))
80-
# for the condition that type_map is a proper superset of atom_names
81-
# new_atoms = set(type_map) - set(data["atom_names"])
82-
new_atoms = [e for e in type_map if e not in data["atom_names"]]
83-
if new_atoms:
84-
data = add_atom_names(data, new_atoms)
85-
# index that will sort an array by type_map
86-
# a[as[a]] == b[as[b]] as == argsort
87-
# as[as[b]] == as^{-1}[b]
88-
# a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id]
89-
idx = np.argsort(data["atom_names"], kind="stable")[
90-
np.argsort(np.argsort(type_map, kind="stable"), kind="stable")
91-
]
78+
# assign atom_names index to the specified order
79+
# only active (numb > 0) atom names must be in type_map
80+
orig_names = data["atom_names"]
81+
orig_numbs = data["atom_numbs"]
82+
active_names = {name for name, numb in zip(orig_names, orig_numbs) if numb > 0}
83+
type_map_set = set(type_map)
84+
if not active_names.issubset(type_map_set):
85+
missing = active_names - type_map_set
86+
raise ValueError(f"Active atom types {missing} not in provided type_map.")
87+
88+
# for the condition that type_map is a proper superset of atom_names,
89+
# we allow new elements with atom_numb = 0.
90+
# Precompute name -> new index once to avoid repeated O(n_types)
91+
# type_map.index(...) calls (which would make the loop O(n_types^2)).
92+
name_to_new_idx = {name: i for i, name in enumerate(type_map)}
93+
94+
# Build new_numbs and the old->new lookup array in a single pass.
95+
# Old names absent from type_map have atom_numb == 0 (validated above)
96+
# and never appear in atom_types, so -1 is a harmless sentinel for
97+
# their slots in the lookup table.
98+
new_names = list(type_map)
99+
new_numbs = [0] * len(type_map)
100+
lookup = np.full(len(orig_names), -1, dtype=np.int64)
101+
for old_idx, name in enumerate(orig_names):
102+
new_idx = name_to_new_idx.get(name)
103+
if new_idx is not None:
104+
lookup[old_idx] = new_idx
105+
new_numbs[new_idx] = orig_numbs[old_idx]
106+
107+
# Remap atom_types with a single vectorized fancy-index operation
108+
# (O(n_atoms + n_types) instead of O(n_types * n_atoms)).
109+
old_types = np.asarray(data["atom_types"])
110+
new_types = lookup[old_types]
111+
112+
# update data in-place
113+
data["atom_names"] = new_names
114+
data["atom_numbs"] = new_numbs
115+
data["atom_types"] = new_types
116+
92117
else:
93118
# index that will sort an array by alphabetical order
119+
# idx = argsort(atom_names) --> atom_names[idx] is sorted
94120
idx = np.argsort(data["atom_names"], kind="stable")
95-
# sort atom_names, atom_numbs, atom_types by idx
96-
data["atom_names"] = list(np.array(data["atom_names"])[idx])
97-
data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx])
98-
data["atom_types"] = np.argsort(idx, kind="stable")[data["atom_types"]]
121+
# sort atom_names and atom_numbs by idx
122+
data["atom_names"] = list(np.array(data["atom_names"])[idx])
123+
data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx])
124+
# to update atom_types: we need the inverse permutation of idx
125+
# because if old_type = i, and atom_names[i] moves to position j,
126+
# then the new type should be j.
127+
# inv_idx = argsort(idx) satisfies: inv_idx[idx[i]] = i
128+
inv_idx = np.argsort(idx, kind="stable")
129+
data["atom_types"] = inv_idx[data["atom_types"]]
130+
99131
return data
100132

101133

tests/test_type_map_utils.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
5+
import numpy as np
6+
7+
from dpdata.utils import sort_atom_names
8+
9+
10+
class TestSortAtomNames(unittest.TestCase):
11+
def test_sort_atom_names_type_map(self):
12+
# Test basic functionality with type_map
13+
data = {
14+
"atom_names": ["H", "O"],
15+
"atom_numbs": [2, 1],
16+
"atom_types": np.array([1, 0, 0]),
17+
}
18+
type_map = ["O", "H"]
19+
result = sort_atom_names(data, type_map=type_map)
20+
21+
self.assertEqual(result["atom_names"], ["O", "H"])
22+
self.assertEqual(result["atom_numbs"], [1, 2])
23+
np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))
24+
25+
def test_sort_atom_names_type_map_with_zero_atoms(self):
26+
# Test with type_map that includes elements with zero atoms
27+
data = {
28+
"atom_names": ["H", "O"],
29+
"atom_numbs": [2, 1],
30+
"atom_types": np.array([1, 0, 0]),
31+
}
32+
type_map = ["O", "H", "C"] # C is not in atom_names but in type_map
33+
result = sort_atom_names(data, type_map=type_map)
34+
35+
self.assertEqual(result["atom_names"], ["O", "H", "C"])
36+
self.assertEqual(result["atom_numbs"], [1, 2, 0])
37+
np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))
38+
39+
def test_sort_atom_names_type_map_missing_active_types(self):
40+
# Test that ValueError is raised when active atom types are missing from type_map
41+
data = {
42+
"atom_names": ["H", "O"],
43+
"atom_numbs": [2, 1], # Both H and O are active (numb > 0)
44+
"atom_types": np.array([1, 0, 0]),
45+
}
46+
type_map = ["H"] # O is active but missing from type_map
47+
48+
with self.assertRaises(ValueError) as cm:
49+
sort_atom_names(data, type_map=type_map)
50+
51+
self.assertIn("Active atom types", str(cm.exception))
52+
self.assertIn("not in provided type_map", str(cm.exception))
53+
self.assertIn("O", str(cm.exception))
54+
55+
def test_sort_atom_names_without_type_map(self):
56+
# Test sorting without type_map (alphabetical order)
57+
data = {
58+
"atom_names": ["Zn", "O", "H"],
59+
"atom_numbs": [1, 1, 2],
60+
"atom_types": np.array([0, 1, 2, 2]),
61+
}
62+
result = sort_atom_names(data)
63+
64+
self.assertEqual(result["atom_names"], ["H", "O", "Zn"])
65+
self.assertEqual(result["atom_numbs"], [2, 1, 1])
66+
np.testing.assert_array_equal(result["atom_types"], np.array([2, 1, 0, 0]))
67+
68+
def test_sort_atom_names_with_zero_count_elements_removed(self):
69+
# Test the case where original elements are A B C, but counts are 0 1 2,
70+
# which should be able to map to B C (removing A which has count 0)
71+
# Example: A, B, C = Cl, O, C
72+
data = {
73+
"atom_names": ["Cl", "O", "C"],
74+
"atom_numbs": [0, 1, 2],
75+
"atom_types": np.array([1, 2, 2]),
76+
}
77+
type_map = ["O", "C"] # Cl is omitted because it has 0 atoms
78+
result = sort_atom_names(data, type_map=type_map)
79+
80+
self.assertEqual(result["atom_names"], ["O", "C"])
81+
self.assertEqual(result["atom_numbs"], [1, 2])
82+
np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1]))
83+
84+
85+
if __name__ == "__main__":
86+
unittest.main()

0 commit comments

Comments
 (0)