@@ -670,6 +670,23 @@ def build(
670670 atype_t = tf .concat ([[self .ntypes ], tf .reshape (self .atype , [- 1 ])], axis = 0 )
671671 self .nei_type_vec = tf .nn .embedding_lookup (atype_t , nlist_t )
672672
673+ if self .spin is not None :
674+ judge = tf .equal (natoms [0 ], natoms [1 ])
675+ self .diff_coord = tf .cond (
676+ judge ,
677+ lambda : self .natoms_match (coord , natoms ),
678+ lambda : self .natoms_not_match (coord , natoms , atype ),
679+ )
680+ diff_coord_reshape = tf .reshape (self .diff_coord , [- 1 , natoms [1 ], 3 ])
681+ nlist_reshaped = tf .reshape (self .nlist , [- 1 , natoms [0 ], self .nnei ])
682+ rj_gathered = tf .gather (
683+ diff_coord_reshape , tf .maximum (nlist_reshaped , 0 ), axis = 1 , batch_dims = 1
684+ )
685+ ri_gathered = diff_coord_reshape [:, : natoms [0 ], tf .newaxis , :]
686+ rij_reshaped = tf .reshape (self .rij , [- 1 , natoms [0 ], self .nnei , 3 ])
687+ rij_update = rij_reshaped - rj_gathered + ri_gathered
688+ self .rij = tf .reshape (rij_update , tf .shape (self .rij ))
689+
673690 # only used when tensorboard was set as true
674691 tf .summary .histogram ("descrpt" , self .descrpt )
675692 tf .summary .histogram ("rij" , self .rij )
@@ -1374,6 +1391,62 @@ def init_variables(
13741391 )
13751392 )
13761393
1394+ def natoms_match (self , coord , natoms ):
1395+ natoms_index = tf .concat ([[0 ], tf .cumsum (natoms [2 :])], axis = 0 )
1396+ diff_coord_loc = []
1397+ for i in range (self .ntypes ):
1398+ if i + self .ntypes_spin >= self .ntypes :
1399+ diff_coord_loc .append (
1400+ tf .slice (
1401+ coord ,
1402+ [0 , natoms_index [i ] * 3 ],
1403+ [- 1 , natoms [2 + i ] * 3 ],
1404+ )
1405+ - tf .slice (
1406+ coord ,
1407+ [0 , natoms_index [i - len (self .spin .use_spin )] * 3 ],
1408+ [- 1 , natoms [2 + i - len (self .spin .use_spin )] * 3 ],
1409+ )
1410+ )
1411+ else :
1412+ diff_coord_loc .append (
1413+ tf .zeros ([tf .shape (coord )[0 ], natoms [2 + i ] * 3 ], dtype = coord .dtype )
1414+ )
1415+ diff_coord_loc = tf .concat (diff_coord_loc , axis = 1 )
1416+ return diff_coord_loc
1417+
1418+ def natoms_not_match (self , coord , natoms , atype ):
1419+ diff_coord_loc = self .natoms_match (coord , natoms )
1420+ diff_coord_ghost = []
1421+ aatype = atype [0 , :]
1422+ ghost_atype = aatype [natoms [0 ] :]
1423+ _ , _ , ghost_natoms = tf .unique_with_counts (ghost_atype )
1424+ ghost_natoms_index = tf .concat ([[0 ], tf .cumsum (ghost_natoms )], axis = 0 )
1425+ ghost_natoms_index += natoms [0 ]
1426+ for i in range (self .ntypes ):
1427+ if i + self .ntypes_spin >= self .ntypes :
1428+ diff_coord_ghost .append (
1429+ tf .slice (
1430+ coord ,
1431+ [0 , ghost_natoms_index [i ] * 3 ],
1432+ [- 1 , ghost_natoms [i ] * 3 ],
1433+ )
1434+ - tf .slice (
1435+ coord ,
1436+ [0 , ghost_natoms_index [i - len (self .spin .use_spin )] * 3 ],
1437+ [- 1 , ghost_natoms [i - len (self .spin .use_spin )] * 3 ],
1438+ )
1439+ )
1440+ else :
1441+ diff_coord_ghost .append (
1442+ tf .zeros (
1443+ [tf .shape (coord )[0 ], ghost_natoms [i ] * 3 ], dtype = coord .dtype
1444+ )
1445+ )
1446+ diff_coord_ghost = tf .concat (diff_coord_ghost , axis = 1 )
1447+ diff_coord = tf .concat ([diff_coord_loc , diff_coord_ghost ], axis = 1 )
1448+ return diff_coord
1449+
13771450 @property
13781451 def explicit_ntypes (self ) -> bool :
13791452 """Explicit ntypes with type embedding."""
0 commit comments