@@ -663,6 +663,23 @@ def build(
663663 atype_t = tf .concat ([[self .ntypes ], tf .reshape (self .atype , [- 1 ])], axis = 0 )
664664 self .nei_type_vec = tf .nn .embedding_lookup (atype_t , nlist_t )
665665
666+ if self .spin is not None :
667+ judge = tf .equal (natoms [0 ], natoms [1 ])
668+ self .diff_coord = tf .cond (
669+ judge ,
670+ lambda : self .natoms_match (coord , natoms ),
671+ lambda : self .natoms_not_match (coord , natoms , atype ),
672+ )
673+ diff_coord_reshape = tf .reshape (self .diff_coord , [- 1 , natoms [1 ], 3 ])
674+ nlist_reshaped = tf .reshape (self .nlist , [- 1 , natoms [0 ], self .nnei ])
675+ rj_gathered = tf .gather (
676+ diff_coord_reshape , tf .maximum (nlist_reshaped , 0 ), axis = 1 , batch_dims = 1
677+ )
678+ ri_gathered = diff_coord_reshape [:, : natoms [0 ], tf .newaxis , :]
679+ rij_reshaped = tf .reshape (self .rij , [- 1 , natoms [0 ], self .nnei , 3 ])
680+ rij_update = rij_reshaped - rj_gathered + ri_gathered
681+ self .rij = tf .reshape (rij_update , tf .shape (self .rij ))
682+
666683 # only used when tensorboard was set as true
667684 tf .summary .histogram ("descrpt" , self .descrpt )
668685 tf .summary .histogram ("rij" , self .rij )
@@ -1355,6 +1372,62 @@ def init_variables(
13551372 )
13561373 )
13571374
1375+ def natoms_match (self , coord , natoms ):
1376+ natoms_index = tf .concat ([[0 ], tf .cumsum (natoms [2 :])], axis = 0 )
1377+ diff_coord_loc = []
1378+ for i in range (self .ntypes ):
1379+ if i + self .ntypes_spin >= self .ntypes :
1380+ diff_coord_loc .append (
1381+ tf .slice (
1382+ coord ,
1383+ [0 , natoms_index [i ] * 3 ],
1384+ [- 1 , natoms [2 + i ] * 3 ],
1385+ )
1386+ - tf .slice (
1387+ coord ,
1388+ [0 , natoms_index [i - len (self .spin .use_spin )] * 3 ],
1389+ [- 1 , natoms [2 + i - len (self .spin .use_spin )] * 3 ],
1390+ )
1391+ )
1392+ else :
1393+ diff_coord_loc .append (
1394+ tf .zeros ([tf .shape (coord )[0 ], natoms [2 + i ] * 3 ], dtype = coord .dtype )
1395+ )
1396+ diff_coord_loc = tf .concat (diff_coord_loc , axis = 1 )
1397+ return diff_coord_loc
1398+
1399+ def natoms_not_match (self , coord , natoms , atype ):
1400+ diff_coord_loc = self .natoms_match (coord , natoms )
1401+ diff_coord_ghost = []
1402+ aatype = atype [0 , :]
1403+ ghost_atype = aatype [natoms [0 ] :]
1404+ _ , _ , ghost_natoms = tf .unique_with_counts (ghost_atype )
1405+ ghost_natoms_index = tf .concat ([[0 ], tf .cumsum (ghost_natoms )], axis = 0 )
1406+ ghost_natoms_index += natoms [0 ]
1407+ for i in range (self .ntypes ):
1408+ if i + self .ntypes_spin >= self .ntypes :
1409+ diff_coord_ghost .append (
1410+ tf .slice (
1411+ coord ,
1412+ [0 , ghost_natoms_index [i ] * 3 ],
1413+ [- 1 , ghost_natoms [i ] * 3 ],
1414+ )
1415+ - tf .slice (
1416+ coord ,
1417+ [0 , ghost_natoms_index [i - len (self .spin .use_spin )] * 3 ],
1418+ [- 1 , ghost_natoms [i - len (self .spin .use_spin )] * 3 ],
1419+ )
1420+ )
1421+ else :
1422+ diff_coord_ghost .append (
1423+ tf .zeros (
1424+ [tf .shape (coord )[0 ], ghost_natoms [i ] * 3 ], dtype = coord .dtype
1425+ )
1426+ )
1427+ diff_coord_ghost = tf .concat (diff_coord_ghost , axis = 1 )
1428+ diff_coord = tf .concat ([diff_coord_loc , diff_coord_ghost ], axis = 1 )
1429+ return diff_coord
1430+
13581431 @property
13591432 def explicit_ntypes (self ) -> bool :
13601433 """Explicit ntypes with type embedding."""
0 commit comments