Skip to content

Commit 2739439

Browse files
committed
fix(tf): add assertion for homogeneous batch in natoms_not_match
Add runtime check to ensure all frames have the same atype vector, preventing silent errors when ghost-atom layouts differ across frames.
1 parent cb0c636 commit 2739439

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

deepmd/tf/descriptor/se_a.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,7 +1391,7 @@ def init_variables(
13911391
)
13921392
)
13931393

1394-
def natoms_match(self, coord, natoms):
1394+
def natoms_match(self, coord: tf.Tensor, natoms: tf.Tensor) -> tf.Tensor:
13951395
natoms_index = tf.concat([[0], tf.cumsum(natoms[2:])], axis=0)
13961396
diff_coord_loc = []
13971397
for i in range(self.ntypes):
@@ -1415,10 +1415,20 @@ def natoms_match(self, coord, natoms):
14151415
diff_coord_loc = tf.concat(diff_coord_loc, axis=1)
14161416
return diff_coord_loc
14171417

1418-
def natoms_not_match(self, coord, natoms, atype):
1418+
def natoms_not_match(
1419+
self, coord: tf.Tensor, natoms: tf.Tensor, atype: tf.Tensor
1420+
) -> tf.Tensor:
14191421
diff_coord_loc = self.natoms_match(coord, natoms)
14201422
diff_coord_ghost = []
1421-
aatype = atype[0, :]
1423+
# Check that all frames have the same atype vector (homogeneous batch)
1424+
# to ensure ghost atom layout is consistent across frames.
1425+
atype_equal = tf.reduce_all(tf.equal(atype, atype[0:1, :]))
1426+
atype_equal = tf.Assert(
1427+
atype_equal,
1428+
["natoms_not_match requires all frames to have the same atype vector"],
1429+
)
1430+
with tf.control_dependencies([atype_equal]):
1431+
aatype = atype[0, :]
14221432
ghost_atype = aatype[natoms[0] :]
14231433
_, _, ghost_natoms = tf.unique_with_counts(ghost_atype)
14241434
ghost_natoms_index = tf.concat([[0], tf.cumsum(ghost_natoms)], axis=0)

0 commit comments

Comments
 (0)