@@ -1150,6 +1150,24 @@ def forward(
11501150 n_edge = h2 .shape [0 ]
11511151 del a_nlist # may be used in the future
11521152
1153+ if self .sequential_update and self .update_angle :
1154+ return self ._forward_sequential (
1155+ node_ebd_ext ,
1156+ edge_ebd ,
1157+ h2 ,
1158+ angle_ebd ,
1159+ nlist ,
1160+ nlist_mask ,
1161+ sw ,
1162+ a_nlist_mask ,
1163+ a_sw ,
1164+ edge_index ,
1165+ angle_index ,
1166+ nb ,
1167+ nloc ,
1168+ n_edge ,
1169+ )
1170+
11531171 n2e_index , n_ext2e_index = edge_index [0 ], edge_index [1 ]
11541172 n2a_index , eij2a_index , eik2a_index = (
11551173 angle_index [0 ],
@@ -1166,95 +1184,6 @@ def forward(
11661184 )
11671185 )
11681186
1169- # Edge self update (always from original embeddings)
1170- edge_self_update = self ._compute_edge_self_update (
1171- node_ebd ,
1172- node_ebd_ext ,
1173- edge_ebd ,
1174- nei_node_ebd ,
1175- nlist ,
1176- n2e_index ,
1177- n_ext2e_index ,
1178- )
1179-
1180- if self .sequential_update and self .update_angle :
1181- # === Sequential update path ===
1182- # Phase 1: Apply edge self residual
1183- edge_ebd_s1 = edge_ebd + self .e_residual [0 ] * edge_self_update
1184-
1185- # Phase 2: Angle self (uses updated edge_ebd_s1)
1186- node_for_a , edge_for_a = self ._prepare_angle_embeddings (
1187- node_ebd , edge_ebd_s1 , a_nlist_mask
1188- )
1189- angle_self_update = self ._compute_angle_update (
1190- angle_ebd ,
1191- node_for_a ,
1192- edge_for_a ,
1193- "angle" ,
1194- n2a_index ,
1195- eij2a_index ,
1196- eik2a_index ,
1197- )
1198- a_updated = angle_ebd + self .a_residual [0 ] * angle_self_update
1199-
1200- # Phase 3: Edge angle (uses updated angle a_updated + edge_ebd_s1)
1201- edge_angle_update = self ._compute_angle_update (
1202- a_updated ,
1203- node_for_a ,
1204- edge_for_a ,
1205- "edge" ,
1206- n2a_index ,
1207- eij2a_index ,
1208- eik2a_index ,
1209- )
1210- edge_angle_processed = self ._compute_edge_angle_reduction (
1211- edge_angle_update ,
1212- edge_ebd_s1 ,
1213- a_sw ,
1214- a_nlist_mask ,
1215- nb ,
1216- nloc ,
1217- n_edge ,
1218- eij2a_index ,
1219- )
1220- e_updated = edge_ebd_s1 + self .e_residual [1 ] * edge_angle_processed
1221-
1222- # Phase 4+5: Node updates (uses e_updated)
1223- node_self_mlp = self .act (self .node_self_mlp (node_ebd ))
1224- node_sym = self ._compute_node_sym (
1225- e_updated , nei_node_ebd , h2 , nlist_mask , sw , n2e_index , nb , nloc
1226- )
1227- node_edge_update = self ._compute_node_edge_message (
1228- node_ebd ,
1229- node_ebd_ext ,
1230- e_updated ,
1231- nei_node_ebd ,
1232- sw ,
1233- nlist ,
1234- n2e_index ,
1235- n_ext2e_index ,
1236- nb ,
1237- nloc ,
1238- )
1239-
1240- n_update_list : list [torch .Tensor ] = [
1241- node_ebd ,
1242- node_self_mlp ,
1243- node_sym ,
1244- ]
1245- if self .n_multi_edge_message > 1 :
1246- node_edge_update_mul_head = node_edge_update .view (
1247- nb , nloc , self .n_multi_edge_message , self .n_dim
1248- )
1249- for head_index in range (self .n_multi_edge_message ):
1250- n_update_list .append (node_edge_update_mul_head [..., head_index , :])
1251- else :
1252- n_update_list .append (node_edge_update )
1253- n_updated = self .list_update (n_update_list , "node" )
1254-
1255- return n_updated , e_updated , a_updated
1256-
1257- # === Parallel update path ===
12581187 n_update_list : list [torch .Tensor ] = [node_ebd ]
12591188 e_update_list : list [torch .Tensor ] = [edge_ebd ]
12601189 a_update_list : list [torch .Tensor ] = [angle_ebd ]
@@ -1295,46 +1224,56 @@ def forward(
12951224 n_updated = self .list_update (n_update_list , "node" )
12961225
12971226 # edge self message
1227+ edge_self_update = self ._compute_edge_self_update (
1228+ node_ebd ,
1229+ node_ebd_ext ,
1230+ edge_ebd ,
1231+ nei_node_ebd ,
1232+ nlist ,
1233+ n2e_index ,
1234+ n_ext2e_index ,
1235+ )
12981236 e_update_list .append (edge_self_update )
12991237
13001238 if self .update_angle :
13011239 assert self .angle_self_linear is not None
13021240 assert self .edge_angle_linear1 is not None
13031241 assert self .edge_angle_linear2 is not None
13041242
1305- node_for_a , edge_for_a = self ._prepare_angle_embeddings (
1243+ node_ebd_for_angle , edge_ebd_for_angle = self ._prepare_angle_embeddings (
13061244 node_ebd , edge_ebd , a_nlist_mask
13071245 )
13081246
13091247 # edge angle message
13101248 edge_angle_update = self ._compute_angle_update (
13111249 angle_ebd ,
1312- node_for_a ,
1313- edge_for_a ,
1250+ node_ebd_for_angle ,
1251+ edge_ebd_for_angle ,
13141252 "edge" ,
13151253 n2a_index ,
13161254 eij2a_index ,
13171255 eik2a_index ,
13181256 )
1319- edge_angle_processed = self ._compute_edge_angle_reduction (
1320- edge_angle_update ,
1321- edge_ebd ,
1322- a_sw ,
1323- a_nlist_mask ,
1324- nb ,
1325- nloc ,
1326- n_edge ,
1327- eij2a_index ,
1257+ e_update_list .append (
1258+ self ._compute_edge_angle_reduction (
1259+ edge_angle_update ,
1260+ edge_ebd ,
1261+ a_sw ,
1262+ a_nlist_mask ,
1263+ nb ,
1264+ nloc ,
1265+ n_edge ,
1266+ eij2a_index ,
1267+ )
13281268 )
1329- e_update_list .append (edge_angle_processed )
13301269 # update edge_ebd
13311270 e_updated = self .list_update (e_update_list , "edge" )
13321271
13331272 # angle self message
13341273 angle_self_update = self ._compute_angle_update (
13351274 angle_ebd ,
1336- node_for_a ,
1337- edge_for_a ,
1275+ node_ebd_for_angle ,
1276+ edge_ebd_for_angle ,
13381277 "angle" ,
13391278 n2a_index ,
13401279 eij2a_index ,
@@ -1349,6 +1288,129 @@ def forward(
13491288 a_updated = self .list_update (a_update_list , "angle" )
13501289 return n_updated , e_updated , a_updated
13511290
1291+ def _forward_sequential (
1292+ self ,
1293+ node_ebd_ext : torch .Tensor ,
1294+ edge_ebd : torch .Tensor ,
1295+ h2 : torch .Tensor ,
1296+ angle_ebd : torch .Tensor ,
1297+ nlist : torch .Tensor ,
1298+ nlist_mask : torch .Tensor ,
1299+ sw : torch .Tensor ,
1300+ a_nlist_mask : torch .Tensor ,
1301+ a_sw : torch .Tensor ,
1302+ edge_index : torch .Tensor ,
1303+ angle_index : torch .Tensor ,
1304+ nb : int ,
1305+ nloc : int ,
1306+ n_edge : int | None ,
1307+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
1308+ """Sequential update path: edge_self -> angle_self -> edge_angle -> node.
1309+
1310+ Each phase consumes the most recent update of its inputs (instead of
1311+ the parallel path which uses the same ``edge_ebd``/``angle_ebd`` for
1312+ all branches). Residuals are applied immediately after each phase.
1313+ """
1314+ assert self .update_style == "res_residual"
1315+ assert self .update_angle
1316+ assert self .angle_self_linear is not None
1317+ assert self .edge_angle_linear1 is not None
1318+ assert self .edge_angle_linear2 is not None
1319+ node_ebd = node_ebd_ext [:, :nloc , :]
1320+ n2e_index , n_ext2e_index = edge_index [0 ], edge_index [1 ]
1321+ n2a_index , eij2a_index , eik2a_index = (
1322+ angle_index [0 ],
1323+ angle_index [1 ],
1324+ angle_index [2 ],
1325+ )
1326+ nei_node_ebd = (
1327+ _make_nei_g1 (node_ebd_ext , nlist )
1328+ if not self .use_dynamic_sel
1329+ else torch .index_select (
1330+ node_ebd_ext .reshape (- 1 , self .n_dim ), 0 , n_ext2e_index
1331+ )
1332+ )
1333+
1334+ # Phase 1: edge self update -> apply residual immediately.
1335+ edge_self_update = self ._compute_edge_self_update (
1336+ node_ebd ,
1337+ node_ebd_ext ,
1338+ edge_ebd ,
1339+ nei_node_ebd ,
1340+ nlist ,
1341+ n2e_index ,
1342+ n_ext2e_index ,
1343+ )
1344+ edge_ebd_s1 = edge_ebd + self .e_residual [0 ] * edge_self_update
1345+
1346+ # Phase 2: angle self message uses the updated edge embedding.
1347+ node_ebd_for_angle , edge_ebd_for_angle = self ._prepare_angle_embeddings (
1348+ node_ebd , edge_ebd_s1 , a_nlist_mask
1349+ )
1350+ angle_self_update = self ._compute_angle_update (
1351+ angle_ebd ,
1352+ node_ebd_for_angle ,
1353+ edge_ebd_for_angle ,
1354+ "angle" ,
1355+ n2a_index ,
1356+ eij2a_index ,
1357+ eik2a_index ,
1358+ )
1359+ a_updated = angle_ebd + self .a_residual [0 ] * angle_self_update
1360+
1361+ # Phase 3: edge angle message uses the updated angle embedding.
1362+ edge_angle_update = self ._compute_angle_update (
1363+ a_updated ,
1364+ node_ebd_for_angle ,
1365+ edge_ebd_for_angle ,
1366+ "edge" ,
1367+ n2a_index ,
1368+ eij2a_index ,
1369+ eik2a_index ,
1370+ )
1371+ edge_angle_reduced = self ._compute_edge_angle_reduction (
1372+ edge_angle_update ,
1373+ edge_ebd_s1 ,
1374+ a_sw ,
1375+ a_nlist_mask ,
1376+ nb ,
1377+ nloc ,
1378+ n_edge ,
1379+ eij2a_index ,
1380+ )
1381+ e_updated = edge_ebd_s1 + self .e_residual [1 ] * edge_angle_reduced
1382+
1383+ # Phase 4: node updates use the fully updated edge embedding.
1384+ n_update_list : list [torch .Tensor ] = [node_ebd ]
1385+ n_update_list .append (self .act (self .node_self_mlp (node_ebd )))
1386+ n_update_list .append (
1387+ self ._compute_node_sym (
1388+ e_updated , nei_node_ebd , h2 , nlist_mask , sw , n2e_index , nb , nloc
1389+ )
1390+ )
1391+ node_edge_update = self ._compute_node_edge_message (
1392+ node_ebd ,
1393+ node_ebd_ext ,
1394+ e_updated ,
1395+ nei_node_ebd ,
1396+ sw ,
1397+ nlist ,
1398+ n2e_index ,
1399+ n_ext2e_index ,
1400+ nb ,
1401+ nloc ,
1402+ )
1403+ if self .n_multi_edge_message > 1 :
1404+ node_edge_update_mul_head = node_edge_update .view (
1405+ nb , nloc , self .n_multi_edge_message , self .n_dim
1406+ )
1407+ for head_index in range (self .n_multi_edge_message ):
1408+ n_update_list .append (node_edge_update_mul_head [..., head_index , :])
1409+ else :
1410+ n_update_list .append (node_edge_update )
1411+ n_updated = self .list_update (n_update_list , "node" )
1412+ return n_updated , e_updated , a_updated
1413+
13521414 @torch .jit .export
13531415 def list_update_res_avg (
13541416 self ,
0 commit comments