@@ -1391,7 +1391,7 @@ def init_variables(
13911391 )
13921392 )
13931393
1394- def natoms_match (self , coord , natoms ) :
1394+ def natoms_match (self , coord : tf . Tensor , natoms : tf . Tensor ) -> tf . Tensor :
13951395 natoms_index = tf .concat ([[0 ], tf .cumsum (natoms [2 :])], axis = 0 )
13961396 diff_coord_loc = []
13971397 for i in range (self .ntypes ):
@@ -1415,10 +1415,20 @@ def natoms_match(self, coord, natoms):
14151415 diff_coord_loc = tf .concat (diff_coord_loc , axis = 1 )
14161416 return diff_coord_loc
14171417
1418- def natoms_not_match (self , coord , natoms , atype ):
1418+ def natoms_not_match (
1419+ self , coord : tf .Tensor , natoms : tf .Tensor , atype : tf .Tensor
1420+ ) -> tf .Tensor :
14191421 diff_coord_loc = self .natoms_match (coord , natoms )
14201422 diff_coord_ghost = []
1421- aatype = atype [0 , :]
1423+ # Check that all frames have the same atype vector (homogeneous batch)
1424+ # to ensure ghost atom layout is consistent across frames.
1425+ atype_equal = tf .reduce_all (tf .equal (atype , atype [0 :1 , :]))
1426+ atype_equal = tf .Assert (
1427+ atype_equal ,
1428+ ["natoms_not_match requires all frames to have the same atype vector" ],
1429+ )
1430+ with tf .control_dependencies ([atype_equal ]):
1431+ aatype = atype [0 , :]
14221432 ghost_atype = aatype [natoms [0 ] :]
14231433 _ , _ , ghost_natoms = tf .unique_with_counts (ghost_atype )
14241434 ghost_natoms_index = tf .concat ([[0 ], tf .cumsum (ghost_natoms )], axis = 0 )
0 commit comments