Skip to content

Commit a54dce5

Browse files
committed
Update repflow_layer.py
1 parent 9cb9ce1 commit a54dce5

1 file changed

Lines changed: 130 additions & 130 deletions

File tree

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 130 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,136 @@ def _forward_sequential(
10361036
n_updated = self.list_update(n_update_list, "node")
10371037
return n_updated, e_updated, a_updated
10381038

1039+
def _compute_node_sym(
1040+
self,
1041+
edge_ebd: torch.Tensor,
1042+
nei_node_ebd: torch.Tensor,
1043+
h2: torch.Tensor,
1044+
nlist_mask: torch.Tensor,
1045+
sw: torch.Tensor,
1046+
n2e_index: torch.Tensor,
1047+
nb: int,
1048+
nloc: int,
1049+
) -> torch.Tensor:
1050+
"""Compute node symmetrization update (grrg + drrd)."""
1051+
node_sym_list: list[torch.Tensor] = []
1052+
node_sym_list.append(
1053+
self.symmetrization_op(
1054+
edge_ebd,
1055+
h2,
1056+
nlist_mask,
1057+
sw,
1058+
self.axis_neuron,
1059+
)
1060+
if not self.use_dynamic_sel
1061+
else self.symmetrization_op_dynamic(
1062+
edge_ebd,
1063+
h2,
1064+
sw,
1065+
owner=n2e_index,
1066+
num_owner=nb * nloc,
1067+
nb=nb,
1068+
nloc=nloc,
1069+
scale_factor=self.dynamic_e_sel ** (-0.5),
1070+
axis_neuron=self.axis_neuron,
1071+
)
1072+
)
1073+
node_sym_list.append(
1074+
self.symmetrization_op(
1075+
nei_node_ebd,
1076+
h2,
1077+
nlist_mask,
1078+
sw,
1079+
self.axis_neuron,
1080+
)
1081+
if not self.use_dynamic_sel
1082+
else self.symmetrization_op_dynamic(
1083+
nei_node_ebd,
1084+
h2,
1085+
sw,
1086+
owner=n2e_index,
1087+
num_owner=nb * nloc,
1088+
nb=nb,
1089+
nloc=nloc,
1090+
scale_factor=self.dynamic_e_sel ** (-0.5),
1091+
axis_neuron=self.axis_neuron,
1092+
)
1093+
)
1094+
return self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1)))
1095+
1096+
def _compute_node_edge_message(
1097+
self,
1098+
node_ebd: torch.Tensor,
1099+
node_ebd_ext: torch.Tensor,
1100+
edge_ebd: torch.Tensor,
1101+
nei_node_ebd: torch.Tensor,
1102+
sw: torch.Tensor,
1103+
nlist: torch.Tensor,
1104+
n2e_index: torch.Tensor,
1105+
n_ext2e_index: torch.Tensor,
1106+
nb: int,
1107+
nloc: int,
1108+
) -> torch.Tensor:
1109+
"""Compute node edge message and reduce over neighbor dimension."""
1110+
if not self.optim_update:
1111+
if not self.use_dynamic_sel:
1112+
edge_info = torch.cat(
1113+
[
1114+
torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]),
1115+
nei_node_ebd,
1116+
edge_ebd,
1117+
],
1118+
dim=-1,
1119+
)
1120+
else:
1121+
edge_info = torch.cat(
1122+
[
1123+
torch.index_select(
1124+
node_ebd.reshape(-1, self.n_dim), 0, n2e_index
1125+
),
1126+
nei_node_ebd,
1127+
edge_ebd,
1128+
],
1129+
dim=-1,
1130+
)
1131+
node_edge_update = self.act(
1132+
self.node_edge_linear(edge_info)
1133+
) * sw.unsqueeze(-1)
1134+
else:
1135+
node_edge_update = self.act(
1136+
self.optim_edge_update(
1137+
node_ebd,
1138+
node_ebd_ext,
1139+
edge_ebd,
1140+
nlist,
1141+
"node",
1142+
)
1143+
if not self.use_dynamic_sel
1144+
else self.optim_edge_update_dynamic(
1145+
node_ebd,
1146+
node_ebd_ext,
1147+
edge_ebd,
1148+
n2e_index,
1149+
n_ext2e_index,
1150+
"node",
1151+
)
1152+
) * sw.unsqueeze(-1)
1153+
1154+
node_edge_update = (
1155+
(torch.sum(node_edge_update, dim=-2) / self.nnei)
1156+
if not self.use_dynamic_sel
1157+
else (
1158+
aggregate(
1159+
node_edge_update,
1160+
n2e_index,
1161+
average=False,
1162+
num_owner=nb * nloc,
1163+
).reshape(nb, nloc, node_edge_update.shape[-1])
1164+
/ self.dynamic_e_sel
1165+
)
1166+
)
1167+
return node_edge_update
1168+
10391169
def _compute_edge_self_update(
10401170
self,
10411171
node_ebd: torch.Tensor,
@@ -1281,136 +1411,6 @@ def _compute_edge_angle_reduction(
12811411

12821412
return self.act(self.edge_angle_linear2(padding_edge_angle_update))
12831413

1284-
def _compute_node_edge_message(
1285-
self,
1286-
node_ebd: torch.Tensor,
1287-
node_ebd_ext: torch.Tensor,
1288-
edge_ebd: torch.Tensor,
1289-
nei_node_ebd: torch.Tensor,
1290-
sw: torch.Tensor,
1291-
nlist: torch.Tensor,
1292-
n2e_index: torch.Tensor,
1293-
n_ext2e_index: torch.Tensor,
1294-
nb: int,
1295-
nloc: int,
1296-
) -> torch.Tensor:
1297-
"""Compute node edge message and reduce over neighbor dimension."""
1298-
if not self.optim_update:
1299-
if not self.use_dynamic_sel:
1300-
edge_info = torch.cat(
1301-
[
1302-
torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]),
1303-
nei_node_ebd,
1304-
edge_ebd,
1305-
],
1306-
dim=-1,
1307-
)
1308-
else:
1309-
edge_info = torch.cat(
1310-
[
1311-
torch.index_select(
1312-
node_ebd.reshape(-1, self.n_dim), 0, n2e_index
1313-
),
1314-
nei_node_ebd,
1315-
edge_ebd,
1316-
],
1317-
dim=-1,
1318-
)
1319-
node_edge_update = self.act(
1320-
self.node_edge_linear(edge_info)
1321-
) * sw.unsqueeze(-1)
1322-
else:
1323-
node_edge_update = self.act(
1324-
self.optim_edge_update(
1325-
node_ebd,
1326-
node_ebd_ext,
1327-
edge_ebd,
1328-
nlist,
1329-
"node",
1330-
)
1331-
if not self.use_dynamic_sel
1332-
else self.optim_edge_update_dynamic(
1333-
node_ebd,
1334-
node_ebd_ext,
1335-
edge_ebd,
1336-
n2e_index,
1337-
n_ext2e_index,
1338-
"node",
1339-
)
1340-
) * sw.unsqueeze(-1)
1341-
1342-
node_edge_update = (
1343-
(torch.sum(node_edge_update, dim=-2) / self.nnei)
1344-
if not self.use_dynamic_sel
1345-
else (
1346-
aggregate(
1347-
node_edge_update,
1348-
n2e_index,
1349-
average=False,
1350-
num_owner=nb * nloc,
1351-
).reshape(nb, nloc, node_edge_update.shape[-1])
1352-
/ self.dynamic_e_sel
1353-
)
1354-
)
1355-
return node_edge_update
1356-
1357-
def _compute_node_sym(
1358-
self,
1359-
edge_ebd: torch.Tensor,
1360-
nei_node_ebd: torch.Tensor,
1361-
h2: torch.Tensor,
1362-
nlist_mask: torch.Tensor,
1363-
sw: torch.Tensor,
1364-
n2e_index: torch.Tensor,
1365-
nb: int,
1366-
nloc: int,
1367-
) -> torch.Tensor:
1368-
"""Compute node symmetrization update (grrg + drrd)."""
1369-
node_sym_list: list[torch.Tensor] = []
1370-
node_sym_list.append(
1371-
self.symmetrization_op(
1372-
edge_ebd,
1373-
h2,
1374-
nlist_mask,
1375-
sw,
1376-
self.axis_neuron,
1377-
)
1378-
if not self.use_dynamic_sel
1379-
else self.symmetrization_op_dynamic(
1380-
edge_ebd,
1381-
h2,
1382-
sw,
1383-
owner=n2e_index,
1384-
num_owner=nb * nloc,
1385-
nb=nb,
1386-
nloc=nloc,
1387-
scale_factor=self.dynamic_e_sel ** (-0.5),
1388-
axis_neuron=self.axis_neuron,
1389-
)
1390-
)
1391-
node_sym_list.append(
1392-
self.symmetrization_op(
1393-
nei_node_ebd,
1394-
h2,
1395-
nlist_mask,
1396-
sw,
1397-
self.axis_neuron,
1398-
)
1399-
if not self.use_dynamic_sel
1400-
else self.symmetrization_op_dynamic(
1401-
nei_node_ebd,
1402-
h2,
1403-
sw,
1404-
owner=n2e_index,
1405-
num_owner=nb * nloc,
1406-
nb=nb,
1407-
nloc=nloc,
1408-
scale_factor=self.dynamic_e_sel ** (-0.5),
1409-
axis_neuron=self.axis_neuron,
1410-
)
1411-
)
1412-
return self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1)))
1413-
14141414
@torch.jit.export
14151415
def list_update_res_avg(
14161416
self,

0 commit comments

Comments
 (0)