@@ -1036,6 +1036,136 @@ def _forward_sequential(
10361036 n_updated = self .list_update (n_update_list , "node" )
10371037 return n_updated , e_updated , a_updated
10381038
1039+ def _compute_node_sym (
1040+ self ,
1041+ edge_ebd : torch .Tensor ,
1042+ nei_node_ebd : torch .Tensor ,
1043+ h2 : torch .Tensor ,
1044+ nlist_mask : torch .Tensor ,
1045+ sw : torch .Tensor ,
1046+ n2e_index : torch .Tensor ,
1047+ nb : int ,
1048+ nloc : int ,
1049+ ) -> torch .Tensor :
1050+ """Compute node symmetrization update (grrg + drrd)."""
1051+ node_sym_list : list [torch .Tensor ] = []
1052+ node_sym_list .append (
1053+ self .symmetrization_op (
1054+ edge_ebd ,
1055+ h2 ,
1056+ nlist_mask ,
1057+ sw ,
1058+ self .axis_neuron ,
1059+ )
1060+ if not self .use_dynamic_sel
1061+ else self .symmetrization_op_dynamic (
1062+ edge_ebd ,
1063+ h2 ,
1064+ sw ,
1065+ owner = n2e_index ,
1066+ num_owner = nb * nloc ,
1067+ nb = nb ,
1068+ nloc = nloc ,
1069+ scale_factor = self .dynamic_e_sel ** (- 0.5 ),
1070+ axis_neuron = self .axis_neuron ,
1071+ )
1072+ )
1073+ node_sym_list .append (
1074+ self .symmetrization_op (
1075+ nei_node_ebd ,
1076+ h2 ,
1077+ nlist_mask ,
1078+ sw ,
1079+ self .axis_neuron ,
1080+ )
1081+ if not self .use_dynamic_sel
1082+ else self .symmetrization_op_dynamic (
1083+ nei_node_ebd ,
1084+ h2 ,
1085+ sw ,
1086+ owner = n2e_index ,
1087+ num_owner = nb * nloc ,
1088+ nb = nb ,
1089+ nloc = nloc ,
1090+ scale_factor = self .dynamic_e_sel ** (- 0.5 ),
1091+ axis_neuron = self .axis_neuron ,
1092+ )
1093+ )
1094+ return self .act (self .node_sym_linear (torch .cat (node_sym_list , dim = - 1 )))
1095+
1096+ def _compute_node_edge_message (
1097+ self ,
1098+ node_ebd : torch .Tensor ,
1099+ node_ebd_ext : torch .Tensor ,
1100+ edge_ebd : torch .Tensor ,
1101+ nei_node_ebd : torch .Tensor ,
1102+ sw : torch .Tensor ,
1103+ nlist : torch .Tensor ,
1104+ n2e_index : torch .Tensor ,
1105+ n_ext2e_index : torch .Tensor ,
1106+ nb : int ,
1107+ nloc : int ,
1108+ ) -> torch .Tensor :
1109+ """Compute node edge message and reduce over neighbor dimension."""
1110+ if not self .optim_update :
1111+ if not self .use_dynamic_sel :
1112+ edge_info = torch .cat (
1113+ [
1114+ torch .tile (node_ebd .unsqueeze (- 2 ), [1 , 1 , self .nnei , 1 ]),
1115+ nei_node_ebd ,
1116+ edge_ebd ,
1117+ ],
1118+ dim = - 1 ,
1119+ )
1120+ else :
1121+ edge_info = torch .cat (
1122+ [
1123+ torch .index_select (
1124+ node_ebd .reshape (- 1 , self .n_dim ), 0 , n2e_index
1125+ ),
1126+ nei_node_ebd ,
1127+ edge_ebd ,
1128+ ],
1129+ dim = - 1 ,
1130+ )
1131+ node_edge_update = self .act (
1132+ self .node_edge_linear (edge_info )
1133+ ) * sw .unsqueeze (- 1 )
1134+ else :
1135+ node_edge_update = self .act (
1136+ self .optim_edge_update (
1137+ node_ebd ,
1138+ node_ebd_ext ,
1139+ edge_ebd ,
1140+ nlist ,
1141+ "node" ,
1142+ )
1143+ if not self .use_dynamic_sel
1144+ else self .optim_edge_update_dynamic (
1145+ node_ebd ,
1146+ node_ebd_ext ,
1147+ edge_ebd ,
1148+ n2e_index ,
1149+ n_ext2e_index ,
1150+ "node" ,
1151+ )
1152+ ) * sw .unsqueeze (- 1 )
1153+
1154+ node_edge_update = (
1155+ (torch .sum (node_edge_update , dim = - 2 ) / self .nnei )
1156+ if not self .use_dynamic_sel
1157+ else (
1158+ aggregate (
1159+ node_edge_update ,
1160+ n2e_index ,
1161+ average = False ,
1162+ num_owner = nb * nloc ,
1163+ ).reshape (nb , nloc , node_edge_update .shape [- 1 ])
1164+ / self .dynamic_e_sel
1165+ )
1166+ )
1167+ return node_edge_update
1168+
10391169 def _compute_edge_self_update (
10401170 self ,
10411171 node_ebd : torch .Tensor ,
@@ -1281,136 +1411,6 @@ def _compute_edge_angle_reduction(
12811411
12821412 return self .act (self .edge_angle_linear2 (padding_edge_angle_update ))
12831413
1284- def _compute_node_edge_message (
1285- self ,
1286- node_ebd : torch .Tensor ,
1287- node_ebd_ext : torch .Tensor ,
1288- edge_ebd : torch .Tensor ,
1289- nei_node_ebd : torch .Tensor ,
1290- sw : torch .Tensor ,
1291- nlist : torch .Tensor ,
1292- n2e_index : torch .Tensor ,
1293- n_ext2e_index : torch .Tensor ,
1294- nb : int ,
1295- nloc : int ,
1296- ) -> torch .Tensor :
1297- """Compute node edge message and reduce over neighbor dimension."""
1298- if not self .optim_update :
1299- if not self .use_dynamic_sel :
1300- edge_info = torch .cat (
1301- [
1302- torch .tile (node_ebd .unsqueeze (- 2 ), [1 , 1 , self .nnei , 1 ]),
1303- nei_node_ebd ,
1304- edge_ebd ,
1305- ],
1306- dim = - 1 ,
1307- )
1308- else :
1309- edge_info = torch .cat (
1310- [
1311- torch .index_select (
1312- node_ebd .reshape (- 1 , self .n_dim ), 0 , n2e_index
1313- ),
1314- nei_node_ebd ,
1315- edge_ebd ,
1316- ],
1317- dim = - 1 ,
1318- )
1319- node_edge_update = self .act (
1320- self .node_edge_linear (edge_info )
1321- ) * sw .unsqueeze (- 1 )
1322- else :
1323- node_edge_update = self .act (
1324- self .optim_edge_update (
1325- node_ebd ,
1326- node_ebd_ext ,
1327- edge_ebd ,
1328- nlist ,
1329- "node" ,
1330- )
1331- if not self .use_dynamic_sel
1332- else self .optim_edge_update_dynamic (
1333- node_ebd ,
1334- node_ebd_ext ,
1335- edge_ebd ,
1336- n2e_index ,
1337- n_ext2e_index ,
1338- "node" ,
1339- )
1340- ) * sw .unsqueeze (- 1 )
1341-
1342- node_edge_update = (
1343- (torch .sum (node_edge_update , dim = - 2 ) / self .nnei )
1344- if not self .use_dynamic_sel
1345- else (
1346- aggregate (
1347- node_edge_update ,
1348- n2e_index ,
1349- average = False ,
1350- num_owner = nb * nloc ,
1351- ).reshape (nb , nloc , node_edge_update .shape [- 1 ])
1352- / self .dynamic_e_sel
1353- )
1354- )
1355- return node_edge_update
1356-
1357- def _compute_node_sym (
1358- self ,
1359- edge_ebd : torch .Tensor ,
1360- nei_node_ebd : torch .Tensor ,
1361- h2 : torch .Tensor ,
1362- nlist_mask : torch .Tensor ,
1363- sw : torch .Tensor ,
1364- n2e_index : torch .Tensor ,
1365- nb : int ,
1366- nloc : int ,
1367- ) -> torch .Tensor :
1368- """Compute node symmetrization update (grrg + drrd)."""
1369- node_sym_list : list [torch .Tensor ] = []
1370- node_sym_list .append (
1371- self .symmetrization_op (
1372- edge_ebd ,
1373- h2 ,
1374- nlist_mask ,
1375- sw ,
1376- self .axis_neuron ,
1377- )
1378- if not self .use_dynamic_sel
1379- else self .symmetrization_op_dynamic (
1380- edge_ebd ,
1381- h2 ,
1382- sw ,
1383- owner = n2e_index ,
1384- num_owner = nb * nloc ,
1385- nb = nb ,
1386- nloc = nloc ,
1387- scale_factor = self .dynamic_e_sel ** (- 0.5 ),
1388- axis_neuron = self .axis_neuron ,
1389- )
1390- )
1391- node_sym_list .append (
1392- self .symmetrization_op (
1393- nei_node_ebd ,
1394- h2 ,
1395- nlist_mask ,
1396- sw ,
1397- self .axis_neuron ,
1398- )
1399- if not self .use_dynamic_sel
1400- else self .symmetrization_op_dynamic (
1401- nei_node_ebd ,
1402- h2 ,
1403- sw ,
1404- owner = n2e_index ,
1405- num_owner = nb * nloc ,
1406- nb = nb ,
1407- nloc = nloc ,
1408- scale_factor = self .dynamic_e_sel ** (- 0.5 ),
1409- axis_neuron = self .axis_neuron ,
1410- )
1411- )
1412- return self .act (self .node_sym_linear (torch .cat (node_sym_list , dim = - 1 )))
1413-
14141414 @torch .jit .export
14151415 def list_update_res_avg (
14161416 self ,
0 commit comments