Skip to content

Commit 8014159

Browse files
Copilotnjzjz
andcommitted
refactor(tf): consolidate duplicate output dimension logic and remove unnecessary try-except
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent d8bbd00 commit 8014159

9 files changed

Lines changed: 16 additions & 10 deletions

File tree

deepmd/tf/model/model.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -829,8 +829,10 @@ def _get_dim_out(self):
829829
"""
830830
if self.model_type == "ener":
831831
return 1
832-
elif self.model_type in ["dipole", "polar"]:
832+
elif self.model_type == "dipole":
833833
return 3
834+
elif self.model_type == "polar":
835+
return 9
834836
elif self.model_type == "dos":
835837
return self.numb_dos
836838
else:
@@ -900,17 +902,17 @@ def _apply_out_bias_std(self, output, atype, natoms, coord, selected_atype=None)
900902
"""
901903
nframes = tf.shape(coord)[0]
902904

905+
# Get output dimension consistently
906+
nout = self._get_dim_out()
907+
903908
if selected_atype is not None:
904909
# For tensor models (dipole, polar) with selected atoms
905910
natomsel = tf.shape(selected_atype)[1]
906-
nout = self.get_out_size() # Use the model's output size method
907911
output_reshaped = tf.reshape(output, [nframes, natomsel, nout])
908912
atype_for_gather = selected_atype
909913
else:
910914
# For energy and DOS models with all atoms
911915
nloc = natoms[0]
912-
# Get output dimension
913-
nout = self._get_dim_out()
914916
if self.model_type == "dos":
915917
# DOS model: output shape [nframes * nloc * numb_dos]
916918
output_reshaped = tf.reshape(output, [nframes, nloc, nout])
@@ -1048,12 +1050,8 @@ def serialize(self, suffix: str = "") -> dict:
10481050
# Get output dimension
10491051
dim_out = self._get_dim_out()
10501052

1051-
# Try to serialize fitting, with fallback for uninitialized variables
1052-
try:
1053-
dict_fit = self.fitting.serialize(suffix=suffix)
1054-
except (AttributeError, TypeError):
1055-
# Fallback: create a minimal dict_fit
1056-
dict_fit = {"@variables": {}}
1053+
# Serialize fitting
1054+
dict_fit = self.fitting.serialize(suffix=suffix)
10571055

10581056
# Use the actual out_bias and out_std if they exist, otherwise create defaults
10591057
if self.out_bias is not None:

system/set.000/aparam.npy

224 Bytes
Binary file not shown.

system/set.000/box.npy

200 Bytes
Binary file not shown.

system/set.000/coord.npy

272 Bytes
Binary file not shown.

system/set.000/energy.npy

136 Bytes
Binary file not shown.

system/set.000/force.npy

272 Bytes
Binary file not shown.

system/set.000/fparam.npy

144 Bytes
Binary file not shown.

system/type.raw

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
0
2+
0
3+
1
4+
1
5+
1
6+
1

system/type_map.raw

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
foo
2+
bar

0 commit comments

Comments
 (0)