Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,19 +826,32 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
# deepcopy is not used for performance reasons
data["fitting"] = data["fitting"].copy()
data["fitting"]["@variables"] = data["fitting"]["@variables"].copy()
if (
int(np.any(data["fitting"]["@variables"]["bias_atom_e"]))
+ int(np.any(data["@variables"]["out_bias"]))
> 1
):
raise ValueError(
"fitting/@variables/bias_atom_e and @variables/out_bias should not be both non-zero"

# For InvarFitting types (ener, dos, property), out_bias can be reshaped and added to bias_atom_e
# For GeneralFitting types (dipole, polar), out_bias and bias_atom_e have different purposes and shapes
# and should not be added together
fitting_type = data["fitting"].get("type", "ener")
if fitting_type in [
"ener",
"dos",
"property",
]:
# For InvarFitting types, use the original logic to reshape and add out_bias to bias_atom_e
if (
int(np.any(data["fitting"]["@variables"]["bias_atom_e"]))
+ int(np.any(data["@variables"]["out_bias"]))
> 1
):
raise ValueError(
"fitting/@variables/bias_atom_e and @variables/out_bias should not be both non-zero"
)
data["fitting"]["@variables"]["bias_atom_e"] = data["fitting"][
"@variables"
]["bias_atom_e"] + data["@variables"]["out_bias"].reshape(
data["fitting"]["@variables"]["bias_atom_e"].shape
)
Comment thread
njzjz marked this conversation as resolved.
data["fitting"]["@variables"]["bias_atom_e"] = data["fitting"][
"@variables"
]["bias_atom_e"] + data["@variables"]["out_bias"].reshape(
data["fitting"]["@variables"]["bias_atom_e"].shape
)
# For GeneralFitting types (dipole, polar), keep out_bias separate - don't add to bias_atom_e
# These fitting types have different bias structures that are incompatible
fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix)
# pass descriptor type embedding to model
if descriptor.explicit_ntypes:
Expand Down Expand Up @@ -887,7 +900,10 @@ def serialize(self, suffix: str = "") -> dict:

ntypes = len(self.get_type_map())
dict_fit = self.fitting.serialize(suffix=suffix)
if dict_fit.get("@variables", {}).get("bias_atom_e") is not None:
if (
dict_fit.get("@variables", {}).get("bias_atom_e") is not None
and dict_fit["dim_out"] == dict_fit["embedding_width"]
):
out_bias = dict_fit["@variables"]["bias_atom_e"].reshape(
[1, ntypes, dict_fit["dim_out"]]
)
Expand Down
Loading