Skip to content

Commit d8bbd00

Browse files
Copilotnjzjz
andcommitted
refactor(tf): use _get_dim_out() method consistently to eliminate duplicate dim calculation logic
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 4a82cfe commit d8bbd00

1 file changed

Lines changed: 9 additions & 11 deletions

File tree

deepmd/tf/model/model.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -909,13 +909,13 @@ def _apply_out_bias_std(self, output, atype, natoms, coord, selected_atype=None)
909909
else:
910910
# For energy and DOS models with all atoms
911911
nloc = natoms[0]
912+
# Get output dimension
913+
nout = self._get_dim_out()
912914
if self.model_type == "dos":
913915
# DOS model: output shape [nframes * nloc * numb_dos]
914-
nout = self.numb_dos
915916
output_reshaped = tf.reshape(output, [nframes, nloc, nout])
916917
else:
917918
# Energy model: output shape [nframes * nloc]
918-
nout = 1
919919
output_reshaped = tf.reshape(output, [nframes, nloc, 1])
920920
atype_for_gather = tf.reshape(atype, [nframes, nloc])
921921

@@ -1045,28 +1045,26 @@ def serialize(self, suffix: str = "") -> dict:
10451045

10461046
ntypes = len(self.get_type_map())
10471047

1048+
# Get output dimension
1049+
dim_out = self._get_dim_out()
1050+
10481051
# Try to serialize fitting, with fallback for uninitialized variables
10491052
try:
10501053
dict_fit = self.fitting.serialize(suffix=suffix)
10511054
except (AttributeError, TypeError):
1052-
# Fallback: create a minimal dict_fit with just dim_out
1053-
dim_out = self._get_dim_out()
1054-
dict_fit = {"dim_out": dim_out, "@variables": {}}
1055+
# Fallback: create a minimal dict_fit
1056+
dict_fit = {"@variables": {}}
10551057

10561058
# Use the actual out_bias and out_std if they exist, otherwise create defaults
10571059
if self.out_bias is not None:
10581060
out_bias = self.out_bias.copy()
10591061
else:
1060-
out_bias = np.zeros(
1061-
[1, ntypes, dict_fit["dim_out"]], dtype=GLOBAL_NP_FLOAT_PRECISION
1062-
)
1062+
out_bias = np.zeros([1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION)
10631063

10641064
if self.out_std is not None:
10651065
out_std = self.out_std.copy()
10661066
else:
1067-
out_std = np.ones(
1068-
[1, ntypes, dict_fit["dim_out"]], dtype=GLOBAL_NP_FLOAT_PRECISION
1069-
)
1067+
out_std = np.ones([1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION)
10701068
return {
10711069
"@class": "Model",
10721070
"type": "standard",

0 commit comments

Comments
 (0)