Skip to content

Commit ff117da

Browse files
committed
Update repflow_layer.py
1 parent 9822484 commit ff117da

1 file changed

Lines changed: 15 additions & 3 deletions

File tree

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -395,6 +395,12 @@ def __init__(
395395
precision=precision,
396396
seed=child_seed(seed, 21),
397397
)
398+
self.node_nga_mlp_out = MLPLayer(
399+
n_dim,
400+
n_dim,
401+
precision=precision,
402+
seed=child_seed(seed, 22),
403+
)
398404
if self.update_style == "res_residual":
399405
self.n_residual.append(
400406
get_residual(
@@ -409,6 +415,7 @@ def __init__(
409415
else:
410416
self.edge_nga_mlp = None
411417
self.node_nga_mlp = None
418+
self.node_nga_mlp_out = None
412419

413420
# edge self message
414421
if not self.use_gated_mlp or self.only_angle_gated_mlp:
@@ -1451,6 +1458,7 @@ def forward(
14511458
if self.EN_use_NGA:
14521459
assert self.node_nga_mlp is not None
14531460
assert self.edge_nga_mlp is not None
1461+
assert self.node_nga_mlp_out is not None
14541462
assert edge_cat_list is not None
14551463
# nb, nloc, nnei, n_dim
14561464
attention_weights_nga_i = self.edge_nga_mlp(edge_cat_list[2])
@@ -1459,13 +1467,17 @@ def forward(
14591467
) - 20.0
14601468
attention_weights_nga_i = torch.softmax(attention_weights_nga_i, dim=-2)
14611469
# nb, nloc, nnei, n_dim
1462-
attention_value_nga = self.node_nga_mlp(
1463-
torch.cat(edge_cat_list[:2], dim=-1)
1470+
attention_value_nga = self.act(
1471+
self.node_nga_mlp(torch.cat(edge_cat_list[:2], dim=-1))
14641472
)
14651473

14661474
# updated value
14671475
# nb, nloc, n_dim
1468-
update_node_nga = (attention_weights_nga_i * attention_value_nga).sum(-2)
1476+
update_node_nga = self.act(
1477+
self.node_nga_mlp_out(
1478+
(attention_weights_nga_i * attention_value_nga).sum(-2)
1479+
)
1480+
)
14691481
n_update_list.append(update_node_nga)
14701482

14711483
# update node_ebd

0 commit comments

Comments (0)