Skip to content

Commit 4a82cfe

Browse files
Copilotnjzjz
andcommitted
refactor(tf): address review feedback - clean up hasattr checks and bias_atom_e fallback
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 0953edf commit 4a82cfe

1 file changed

Lines changed: 22 additions & 49 deletions

File tree

deepmd/tf/model/model.py

Lines changed: 22 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -819,35 +819,37 @@ def get_ntypes(self) -> int:
819819
"""Get the number of types."""
820820
return self.ntypes
821821

822+
def _get_dim_out(self):
823+
"""Get output dimension based on model type.
824+
825+
Returns
826+
-------
827+
int
828+
Output dimension
829+
"""
830+
if self.model_type == "ener":
831+
return 1
832+
elif self.model_type in ["dipole", "polar"]:
833+
return 3
834+
elif self.model_type == "dos":
835+
return self.numb_dos
836+
else:
837+
return 1
838+
822839
def init_out_stat(self, suffix: str = "") -> None:
823840
"""Initialize the output bias and std variables."""
824841
ntypes = self.get_ntypes()
825-
826-
# Determine output dimension based on model type instead of fitting type
827-
if hasattr(self, "model_type"):
828-
model_type = self.model_type
829-
else:
830-
# Fallback to fitting type for compatibility
831-
model_type = getattr(self.fitting, "model_type", "ener")
832-
833-
if model_type == "ener":
834-
dim_out = 1
835-
elif model_type in ["dipole", "polar"]:
836-
dim_out = 3
837-
elif model_type == "dos":
838-
dim_out = getattr(self.fitting, "numb_dos", 1)
839-
else:
840-
dim_out = 1
842+
dim_out = self._get_dim_out()
841843

842844
# Initialize out_bias and out_std as numpy arrays, preserving existing values if set
843-
if hasattr(self, "out_bias") and self.out_bias is not None:
845+
if self.out_bias is not None:
844846
out_bias_data = self.out_bias.copy()
845847
else:
846848
out_bias_data = np.zeros(
847849
[1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION
848850
)
849851

850-
if hasattr(self, "out_std") and self.out_std is not None:
852+
if self.out_std is not None:
851853
out_std_data = self.out_std.copy()
852854
else:
853855
out_std_data = np.ones(
@@ -907,7 +909,7 @@ def _apply_out_bias_std(self, output, atype, natoms, coord, selected_atype=None)
907909
else:
908910
# For energy and DOS models with all atoms
909911
nloc = natoms[0]
910-
if hasattr(self, "numb_dos"):
912+
if self.model_type == "dos":
911913
# DOS model: output shape [nframes * nloc * numb_dos]
912914
nout = self.numb_dos
913915
output_reshaped = tf.reshape(output, [nframes, nloc, nout])
@@ -1048,41 +1050,12 @@ def serialize(self, suffix: str = "") -> dict:
10481050
dict_fit = self.fitting.serialize(suffix=suffix)
10491051
except (AttributeError, TypeError):
10501052
# Fallback: create a minimal dict_fit with just dim_out
1051-
from deepmd.tf.fit.dipole import (
1052-
DipoleFittingSeA,
1053-
)
1054-
from deepmd.tf.fit.dos import (
1055-
DOSFitting,
1056-
)
1057-
from deepmd.tf.fit.ener import (
1058-
EnerFitting,
1059-
)
1060-
from deepmd.tf.fit.polar import (
1061-
PolarFittingSeA,
1062-
)
1063-
1064-
if isinstance(self.fitting, EnerFitting):
1065-
dim_out = 1
1066-
elif isinstance(self.fitting, (DipoleFittingSeA, PolarFittingSeA)):
1067-
dim_out = 3
1068-
elif isinstance(self.fitting, DOSFitting):
1069-
dim_out = getattr(self.fitting, "numb_dos", 1)
1070-
else:
1071-
dim_out = 1
1072-
1053+
dim_out = self._get_dim_out()
10731054
dict_fit = {"dim_out": dim_out, "@variables": {}}
10741055

10751056
# Use the actual out_bias and out_std if they exist, otherwise create defaults
10761057
if self.out_bias is not None:
10771058
out_bias = self.out_bias.copy()
1078-
elif dict_fit.get("@variables", {}).get("bias_atom_e") is not None:
1079-
# Fallback to converting bias_atom_e for backward compatibility
1080-
out_bias = dict_fit["@variables"]["bias_atom_e"].reshape(
1081-
[1, ntypes, dict_fit["dim_out"]]
1082-
)
1083-
dict_fit["@variables"]["bias_atom_e"] = np.zeros_like(
1084-
dict_fit["@variables"]["bias_atom_e"]
1085-
)
10861059
else:
10871060
out_bias = np.zeros(
10881061
[1, ntypes, dict_fit["dim_out"]], dtype=GLOBAL_NP_FLOAT_PRECISION

0 commit comments

Comments
 (0)