Skip to content

Commit 1369a46

Browse files
committed
Update repflow_layer.py
1 parent daf68ba commit 1369a46

1 file changed

Lines changed: 166 additions & 104 deletions

File tree

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 166 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,24 @@ def forward(
11501150
n_edge = h2.shape[0]
11511151
del a_nlist # may be used in the future
11521152

1153+
if self.sequential_update and self.update_angle:
1154+
return self._forward_sequential(
1155+
node_ebd_ext,
1156+
edge_ebd,
1157+
h2,
1158+
angle_ebd,
1159+
nlist,
1160+
nlist_mask,
1161+
sw,
1162+
a_nlist_mask,
1163+
a_sw,
1164+
edge_index,
1165+
angle_index,
1166+
nb,
1167+
nloc,
1168+
n_edge,
1169+
)
1170+
11531171
n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
11541172
n2a_index, eij2a_index, eik2a_index = (
11551173
angle_index[0],
@@ -1166,95 +1184,6 @@ def forward(
11661184
)
11671185
)
11681186

1169-
# Edge self update (always from original embeddings)
1170-
edge_self_update = self._compute_edge_self_update(
1171-
node_ebd,
1172-
node_ebd_ext,
1173-
edge_ebd,
1174-
nei_node_ebd,
1175-
nlist,
1176-
n2e_index,
1177-
n_ext2e_index,
1178-
)
1179-
1180-
if self.sequential_update and self.update_angle:
1181-
# === Sequential update path ===
1182-
# Phase 1: Apply edge self residual
1183-
edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update
1184-
1185-
# Phase 2: Angle self (uses updated edge_ebd_s1)
1186-
node_for_a, edge_for_a = self._prepare_angle_embeddings(
1187-
node_ebd, edge_ebd_s1, a_nlist_mask
1188-
)
1189-
angle_self_update = self._compute_angle_update(
1190-
angle_ebd,
1191-
node_for_a,
1192-
edge_for_a,
1193-
"angle",
1194-
n2a_index,
1195-
eij2a_index,
1196-
eik2a_index,
1197-
)
1198-
a_updated = angle_ebd + self.a_residual[0] * angle_self_update
1199-
1200-
# Phase 3: Edge angle (uses updated angle a_updated + edge_ebd_s1)
1201-
edge_angle_update = self._compute_angle_update(
1202-
a_updated,
1203-
node_for_a,
1204-
edge_for_a,
1205-
"edge",
1206-
n2a_index,
1207-
eij2a_index,
1208-
eik2a_index,
1209-
)
1210-
edge_angle_processed = self._compute_edge_angle_reduction(
1211-
edge_angle_update,
1212-
edge_ebd_s1,
1213-
a_sw,
1214-
a_nlist_mask,
1215-
nb,
1216-
nloc,
1217-
n_edge,
1218-
eij2a_index,
1219-
)
1220-
e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed
1221-
1222-
# Phase 4+5: Node updates (uses e_updated)
1223-
node_self_mlp = self.act(self.node_self_mlp(node_ebd))
1224-
node_sym = self._compute_node_sym(
1225-
e_updated, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc
1226-
)
1227-
node_edge_update = self._compute_node_edge_message(
1228-
node_ebd,
1229-
node_ebd_ext,
1230-
e_updated,
1231-
nei_node_ebd,
1232-
sw,
1233-
nlist,
1234-
n2e_index,
1235-
n_ext2e_index,
1236-
nb,
1237-
nloc,
1238-
)
1239-
1240-
n_update_list: list[torch.Tensor] = [
1241-
node_ebd,
1242-
node_self_mlp,
1243-
node_sym,
1244-
]
1245-
if self.n_multi_edge_message > 1:
1246-
node_edge_update_mul_head = node_edge_update.view(
1247-
nb, nloc, self.n_multi_edge_message, self.n_dim
1248-
)
1249-
for head_index in range(self.n_multi_edge_message):
1250-
n_update_list.append(node_edge_update_mul_head[..., head_index, :])
1251-
else:
1252-
n_update_list.append(node_edge_update)
1253-
n_updated = self.list_update(n_update_list, "node")
1254-
1255-
return n_updated, e_updated, a_updated
1256-
1257-
# === Parallel update path ===
12581187
n_update_list: list[torch.Tensor] = [node_ebd]
12591188
e_update_list: list[torch.Tensor] = [edge_ebd]
12601189
a_update_list: list[torch.Tensor] = [angle_ebd]
@@ -1295,46 +1224,56 @@ def forward(
12951224
n_updated = self.list_update(n_update_list, "node")
12961225

12971226
# edge self message
1227+
edge_self_update = self._compute_edge_self_update(
1228+
node_ebd,
1229+
node_ebd_ext,
1230+
edge_ebd,
1231+
nei_node_ebd,
1232+
nlist,
1233+
n2e_index,
1234+
n_ext2e_index,
1235+
)
12981236
e_update_list.append(edge_self_update)
12991237

13001238
if self.update_angle:
13011239
assert self.angle_self_linear is not None
13021240
assert self.edge_angle_linear1 is not None
13031241
assert self.edge_angle_linear2 is not None
13041242

1305-
node_for_a, edge_for_a = self._prepare_angle_embeddings(
1243+
node_ebd_for_angle, edge_ebd_for_angle = self._prepare_angle_embeddings(
13061244
node_ebd, edge_ebd, a_nlist_mask
13071245
)
13081246

13091247
# edge angle message
13101248
edge_angle_update = self._compute_angle_update(
13111249
angle_ebd,
1312-
node_for_a,
1313-
edge_for_a,
1250+
node_ebd_for_angle,
1251+
edge_ebd_for_angle,
13141252
"edge",
13151253
n2a_index,
13161254
eij2a_index,
13171255
eik2a_index,
13181256
)
1319-
edge_angle_processed = self._compute_edge_angle_reduction(
1320-
edge_angle_update,
1321-
edge_ebd,
1322-
a_sw,
1323-
a_nlist_mask,
1324-
nb,
1325-
nloc,
1326-
n_edge,
1327-
eij2a_index,
1257+
e_update_list.append(
1258+
self._compute_edge_angle_reduction(
1259+
edge_angle_update,
1260+
edge_ebd,
1261+
a_sw,
1262+
a_nlist_mask,
1263+
nb,
1264+
nloc,
1265+
n_edge,
1266+
eij2a_index,
1267+
)
13281268
)
1329-
e_update_list.append(edge_angle_processed)
13301269
# update edge_ebd
13311270
e_updated = self.list_update(e_update_list, "edge")
13321271

13331272
# angle self message
13341273
angle_self_update = self._compute_angle_update(
13351274
angle_ebd,
1336-
node_for_a,
1337-
edge_for_a,
1275+
node_ebd_for_angle,
1276+
edge_ebd_for_angle,
13381277
"angle",
13391278
n2a_index,
13401279
eij2a_index,
@@ -1349,6 +1288,129 @@ def forward(
13491288
a_updated = self.list_update(a_update_list, "angle")
13501289
return n_updated, e_updated, a_updated
13511290

1291+
def _forward_sequential(
1292+
self,
1293+
node_ebd_ext: torch.Tensor,
1294+
edge_ebd: torch.Tensor,
1295+
h2: torch.Tensor,
1296+
angle_ebd: torch.Tensor,
1297+
nlist: torch.Tensor,
1298+
nlist_mask: torch.Tensor,
1299+
sw: torch.Tensor,
1300+
a_nlist_mask: torch.Tensor,
1301+
a_sw: torch.Tensor,
1302+
edge_index: torch.Tensor,
1303+
angle_index: torch.Tensor,
1304+
nb: int,
1305+
nloc: int,
1306+
n_edge: int | None,
1307+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1308+
"""Sequential update path: edge_self -> angle_self -> edge_angle -> node.
1309+
1310+
Each phase consumes the most recent update of its inputs (instead of
1311+
the parallel path which uses the same ``edge_ebd``/``angle_ebd`` for
1312+
all branches). Residuals are applied immediately after each phase.
1313+
"""
1314+
assert self.update_style == "res_residual"
1315+
assert self.update_angle
1316+
assert self.angle_self_linear is not None
1317+
assert self.edge_angle_linear1 is not None
1318+
assert self.edge_angle_linear2 is not None
1319+
node_ebd = node_ebd_ext[:, :nloc, :]
1320+
n2e_index, n_ext2e_index = edge_index[0], edge_index[1]
1321+
n2a_index, eij2a_index, eik2a_index = (
1322+
angle_index[0],
1323+
angle_index[1],
1324+
angle_index[2],
1325+
)
1326+
nei_node_ebd = (
1327+
_make_nei_g1(node_ebd_ext, nlist)
1328+
if not self.use_dynamic_sel
1329+
else torch.index_select(
1330+
node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index
1331+
)
1332+
)
1333+
1334+
# Phase 1: edge self update -> apply residual immediately.
1335+
edge_self_update = self._compute_edge_self_update(
1336+
node_ebd,
1337+
node_ebd_ext,
1338+
edge_ebd,
1339+
nei_node_ebd,
1340+
nlist,
1341+
n2e_index,
1342+
n_ext2e_index,
1343+
)
1344+
edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update
1345+
1346+
# Phase 2: angle self message uses the updated edge embedding.
1347+
node_ebd_for_angle, edge_ebd_for_angle = self._prepare_angle_embeddings(
1348+
node_ebd, edge_ebd_s1, a_nlist_mask
1349+
)
1350+
angle_self_update = self._compute_angle_update(
1351+
angle_ebd,
1352+
node_ebd_for_angle,
1353+
edge_ebd_for_angle,
1354+
"angle",
1355+
n2a_index,
1356+
eij2a_index,
1357+
eik2a_index,
1358+
)
1359+
a_updated = angle_ebd + self.a_residual[0] * angle_self_update
1360+
1361+
# Phase 3: edge angle message uses the updated angle embedding.
1362+
edge_angle_update = self._compute_angle_update(
1363+
a_updated,
1364+
node_ebd_for_angle,
1365+
edge_ebd_for_angle,
1366+
"edge",
1367+
n2a_index,
1368+
eij2a_index,
1369+
eik2a_index,
1370+
)
1371+
edge_angle_reduced = self._compute_edge_angle_reduction(
1372+
edge_angle_update,
1373+
edge_ebd_s1,
1374+
a_sw,
1375+
a_nlist_mask,
1376+
nb,
1377+
nloc,
1378+
n_edge,
1379+
eij2a_index,
1380+
)
1381+
e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_reduced
1382+
1383+
# Phase 4: node updates use the fully updated edge embedding.
1384+
n_update_list: list[torch.Tensor] = [node_ebd]
1385+
n_update_list.append(self.act(self.node_self_mlp(node_ebd)))
1386+
n_update_list.append(
1387+
self._compute_node_sym(
1388+
e_updated, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc
1389+
)
1390+
)
1391+
node_edge_update = self._compute_node_edge_message(
1392+
node_ebd,
1393+
node_ebd_ext,
1394+
e_updated,
1395+
nei_node_ebd,
1396+
sw,
1397+
nlist,
1398+
n2e_index,
1399+
n_ext2e_index,
1400+
nb,
1401+
nloc,
1402+
)
1403+
if self.n_multi_edge_message > 1:
1404+
node_edge_update_mul_head = node_edge_update.view(
1405+
nb, nloc, self.n_multi_edge_message, self.n_dim
1406+
)
1407+
for head_index in range(self.n_multi_edge_message):
1408+
n_update_list.append(node_edge_update_mul_head[..., head_index, :])
1409+
else:
1410+
n_update_list.append(node_edge_update)
1411+
n_updated = self.list_update(n_update_list, "node")
1412+
return n_updated, e_updated, a_updated
1413+
13521414
@torch.jit.export
13531415
def list_update_res_avg(
13541416
self,

0 commit comments

Comments
 (0)