Skip to content

Commit dfc5f03

Browse files
committed
pt to tf: get sel_type from complement of atom_exclude_types
1 parent b9c82d7 commit dfc5f03

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

deepmd/tf/model/model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,16 +1003,23 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
10031003
check_version_compatibility(data.pop("@version", 2), 2, 1)
10041004
descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix)
10051005
# bias_atom_e and out_bias are now completely independent - no conversion needed
1006-
fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix)
1006+
fitting_dict = data.pop("fitting", {})
1007+
atom_exclude_types = data.pop("atom_exclude_types", [])
1008+
if len(atom_exclude_types) > 0:
1009+
# get sel_type from complement of atom_exclude_types
1010+
full_type_list = np.arange(len(data["type_map"]), dtype=int)
1011+
sel_type = np.setdiff1d(
1012+
full_type_list, atom_exclude_types, assume_unique=True
1013+
)
1014+
fitting_dict["sel_type"] = sel_type
1015+
fitting = Fitting.deserialize(fitting_dict, suffix=suffix)
10071016
# pass descriptor type embedding to model
10081017
if descriptor.explicit_ntypes:
10091018
type_embedding = descriptor.type_embedding
10101019
fitting.dim_descrpt -= type_embedding.neuron[-1]
10111020
else:
10121021
type_embedding = None
10131022
# BEGINE not supported keys
1014-
if len(data.pop("atom_exclude_types")) > 0:
1015-
raise NotImplementedError("atom_exclude_types is not supported")
10161023
if len(data.pop("pair_exclude_types")) > 0:
10171024
raise NotImplementedError("pair_exclude_types is not supported")
10181025
data.pop("rcond", None)

0 commit comments

Comments
 (0)