diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 3377ed2d51..5e26569524 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -826,19 +826,32 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": # deepcopy is not used for performance reasons data["fitting"] = data["fitting"].copy() data["fitting"]["@variables"] = data["fitting"]["@variables"].copy() - if ( - int(np.any(data["fitting"]["@variables"]["bias_atom_e"])) - + int(np.any(data["@variables"]["out_bias"])) - > 1 - ): - raise ValueError( - "fitting/@variables/bias_atom_e and @variables/out_bias should not be both non-zero" + + # For InvarFitting types (ener, dos, property), out_bias can be reshaped and added to bias_atom_e + # For GeneralFitting types (dipole, polar), out_bias and bias_atom_e have different purposes and shapes + # and should not be added together + fitting_type = data["fitting"].get("type", "ener") + if fitting_type in [ + "ener", + "dos", + "property", + ]: + # For InvarFitting types, use the original logic to reshape and add out_bias to bias_atom_e + if ( + int(np.any(data["fitting"]["@variables"]["bias_atom_e"])) + + int(np.any(data["@variables"]["out_bias"])) + > 1 + ): + raise ValueError( + "fitting/@variables/bias_atom_e and @variables/out_bias should not be both non-zero" + ) + data["fitting"]["@variables"]["bias_atom_e"] = data["fitting"][ + "@variables" + ]["bias_atom_e"] + data["@variables"]["out_bias"].reshape( + data["fitting"]["@variables"]["bias_atom_e"].shape ) - data["fitting"]["@variables"]["bias_atom_e"] = data["fitting"][ - "@variables" - ]["bias_atom_e"] + data["@variables"]["out_bias"].reshape( - data["fitting"]["@variables"]["bias_atom_e"].shape - ) + # For GeneralFitting types (dipole, polar), keep out_bias separate - don't add to bias_atom_e + # These fitting types have different bias structures that are incompatible fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix) # pass descriptor type embedding to model if descriptor.explicit_ntypes: @@ -887,7 +900,10 @@ def serialize(self, suffix: str = "") -> dict: ntypes = len(self.get_type_map()) dict_fit = self.fitting.serialize(suffix=suffix) - if dict_fit.get("@variables", {}).get("bias_atom_e") is not None: + if ( + dict_fit.get("@variables", {}).get("bias_atom_e") is not None + and dict_fit["dim_out"] == dict_fit["embedding_width"] + ): out_bias = dict_fit["@variables"]["bias_atom_e"].reshape( [1, ntypes, dict_fit["dim_out"]] )