Skip to content

Commit bbb2a49

Browse files
committed
minor revision according to coderabbit
1 parent be6c622 commit bbb2a49

3 files changed

Lines changed: 7 additions & 2 deletions

File tree

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def change_type_map(
152152
else None,
153153
)
154154
self.fitting_net.change_type_map(type_map=type_map)
155+
# Reinitialize fitting to get correct sel_type
156+
self.fitting_net.reinit_exclude(self.atom_exclude_types)
155157

156158
def has_message_passing(self) -> bool:
157159
"""Returns whether the atomic model has message passing."""

deepmd/tf/model/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
10111011
sel_type = np.setdiff1d(
10121012
full_type_list, atom_exclude_types, assume_unique=True
10131013
)
1014-
fitting_dict["sel_type"] = sel_type
1014+
fitting_dict["sel_type"] = sel_type.tolist()
10151015
fitting = Fitting.deserialize(fitting_dict, suffix=suffix)
10161016
# pass descriptor type embedding to model
10171017
if descriptor.explicit_ntypes:

source/tests/pt/model/test_get_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ def test_model_attr(self) -> None:
6161
},
6262
)
6363
full_type_list = np.arange(len(atomic_model.type_map), dtype=int)
64-
atom_exclude_types = np.setdiff1d(full_type_list, self.model.get_sel_type())
64+
atom_exclude_types = np.setdiff1d(
65+
full_type_list,
66+
self.model.get_sel_type(),
67+
).tolist()
6568
self.assertEqual(atom_exclude_types, [1])
6669
self.assertEqual(atomic_model.atom_exclude_types, [1])
6770
self.assertEqual(atomic_model.pair_exclude_types, [[1, 2]])

0 commit comments

Comments
 (0)