Skip to content

Commit b062d76

Browse files
committed
add angle NGAi
1 parent ff117da commit b062d76

5 files changed

Lines changed: 91 additions & 1 deletion

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(
8989
angle_use_fixed_gaussian: bool = False,
9090
angle_fixed_gaussian_interpolate: bool = False,
9191
EN_use_NGA: bool = False,
92+
AE_use_NGA: bool = False,
9293
) -> None:
9394
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
9495
@@ -226,6 +227,7 @@ def __init__(
226227
self.angle_use_fixed_gaussian = angle_use_fixed_gaussian
227228
self.angle_fixed_gaussian_interpolate = angle_fixed_gaussian_interpolate
228229
self.EN_use_NGA = EN_use_NGA
230+
self.AE_use_NGA = AE_use_NGA
229231
assert (
230232
fix_stat_std == 0.3
231233
), "fix_stat_std is not implemented in this version, please use skip_stat instead."

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def init_subclass_params(sub_data, sub_class):
228228
angle_use_fixed_gaussian=self.repflow_args.angle_use_fixed_gaussian,
229229
angle_fixed_gaussian_interpolate=self.repflow_args.angle_fixed_gaussian_interpolate,
230230
EN_use_NGA=self.repflow_args.EN_use_NGA,
231+
AE_use_NGA=self.repflow_args.AE_use_NGA,
231232
exclude_types=exclude_types,
232233
env_protection=env_protection,
233234
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
edge_message_use_dropout: bool = False,
9898
angle_message_use_dropout: bool = False,
9999
EN_use_NGA: bool = False,
100+
AE_use_NGA: bool = False,
100101
dropout_rate: float = 0.1,
101102
activation_function: str = "silu",
102103
update_style: str = "res_residual",
@@ -148,7 +149,8 @@ def __init__(
148149
self.dynamic_e_sel = self.nnei / self.sel_reduce_factor
149150
self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor
150151
self.EN_use_NGA = EN_use_NGA
151-
if self.EN_use_NGA:
152+
self.AE_use_NGA = AE_use_NGA
153+
if self.EN_use_NGA or self.AE_use_NGA:
152154
assert (
153155
not self.use_dynamic_sel and not self.optim_update
154156
), "NGA does not support dynamic selection or optim update!"
@@ -575,6 +577,42 @@ def __init__(
575577
)
576578
residual_idx += 1
577579

580+
# edge angle NGA
581+
if self.AE_use_NGA:
582+
self.AE_angle_nga_mlp = MLPLayer(
583+
a_dim,
584+
e_dim,
585+
precision=precision,
586+
seed=child_seed(seed, 24),
587+
)
588+
self.AE_edge_nga_mlp = MLPLayer(
589+
2 * self.e_a_compress_dim,
590+
e_dim,
591+
precision=precision,
592+
seed=child_seed(seed, 25),
593+
)
594+
self.AE_edge_nga_mlp_out = MLPLayer(
595+
e_dim,
596+
e_dim,
597+
precision=precision,
598+
seed=child_seed(seed, 26),
599+
)
600+
if self.update_style == "res_residual":
601+
self.e_residual.append(
602+
get_residual(
603+
e_dim,
604+
self.update_residual * self.residual_pref[residual_idx],
605+
self.update_residual_init,
606+
precision=precision,
607+
seed=child_seed(seed, 27),
608+
)
609+
)
610+
residual_idx += 1
611+
else:
612+
self.AE_angle_nga_mlp = None
613+
self.AE_edge_nga_mlp = None
614+
self.AE_edge_nga_mlp_out = None
615+
578616
# angle self message
579617
if not self.use_gated_mlp:
580618
self.angle_self_linear = MLPLayer(
@@ -1455,6 +1493,7 @@ def forward(
14551493
)
14561494
n_update_list.append(node_edge_update)
14571495

1496+
# node edge nga
14581497
if self.EN_use_NGA:
14591498
assert self.node_nga_mlp is not None
14601499
assert self.edge_nga_mlp is not None
@@ -1613,6 +1652,7 @@ def forward(
16131652
else:
16141653
angle_info = None
16151654
angle_info_ffn = None
1655+
angle_info_list = None
16161656

16171657
# angle message use dropout
16181658
if self.angle_message_use_dropout:
@@ -1725,6 +1765,44 @@ def forward(
17251765
padding_edge_angle_update = self.EAM_rmsnorm(padding_edge_angle_update)
17261766

17271767
e_update_list.append(padding_edge_angle_update)
1768+
1769+
# edge angle NGA
1770+
if self.AE_use_NGA:
1771+
assert self.AE_angle_nga_mlp is not None
1772+
assert self.AE_edge_nga_mlp is not None
1773+
assert self.AE_edge_nga_mlp_out is not None
1774+
assert angle_info_list is not None
1775+
1776+
# nb, nloc, a_nnei, a_nnei, e_dim
1777+
attention_weights_nga_i = self.AE_angle_nga_mlp(angle_info_list[0])
1778+
attention_weights_nga_i = (
1779+
attention_weights_nga_i + 20.0
1780+
) * a_sw.unsqueeze(-1).unsqueeze(-1) - 20.0
1781+
attention_weights_nga_i = torch.softmax(attention_weights_nga_i, dim=-2)
1782+
# nb, nloc, a_nnei, a_nnei, e_dim
1783+
attention_value_nga = self.act(self.AE_edge_nga_mlp(angle_info_list[2]))
1784+
1785+
# updated value
1786+
# nb, nloc, a_nnei, e_dim
1787+
reduce_edge_nga = (attention_weights_nga_i * attention_value_nga).sum(
1788+
-2
1789+
)
1790+
# nb x nloc x nnei x e_dim
1791+
padding_edge_nga = torch.concat(
1792+
[
1793+
reduce_edge_nga,
1794+
torch.zeros(
1795+
[nb, nloc, self.nnei - self.a_sel, self.e_dim],
1796+
dtype=edge_ebd.dtype,
1797+
device=edge_ebd.device,
1798+
),
1799+
],
1800+
dim=2,
1801+
)
1802+
# nb x nloc x nnei x e_dim
1803+
update_edge_nga = self.act(self.AE_edge_nga_mlp_out(padding_edge_nga))
1804+
e_update_list.append(update_edge_nga)
1805+
17281806
# update edge_ebd
17291807
e_updated = self.list_update(e_update_list, "edge")
17301808

deepmd/pt/model/descriptor/repflows.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def __init__(
172172
angle_use_fixed_gaussian: bool = False,
173173
angle_fixed_gaussian_interpolate: bool = False,
174174
EN_use_NGA: bool = False,
175+
AE_use_NGA: bool = False,
175176
seed: Optional[Union[int, list[int]]] = None,
176177
) -> None:
177178
r"""
@@ -318,6 +319,7 @@ def __init__(
318319
else:
319320
self.angle_gaussian_encoder = None
320321
self.EN_use_NGA = EN_use_NGA
322+
self.AE_use_NGA = AE_use_NGA
321323
self.use_env_envelope = use_env_envelope
322324
self.use_new_sw = use_new_sw
323325
self.use_force_embedding = use_force_embedding
@@ -585,6 +587,7 @@ def __init__(
585587
edge_message_use_dropout=self.edge_message_use_dropout,
586588
angle_message_use_dropout=self.angle_message_use_dropout,
587589
EN_use_NGA=self.EN_use_NGA,
590+
AE_use_NGA=self.AE_use_NGA,
588591
dropout_rate=self.dropout_rate,
589592
seed=child_seed(child_seed(seed, 1), ii),
590593
)

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,6 +1947,12 @@ def dpa3_repflow_args():
19471947
optional=True,
19481948
default=False,
19491949
),
1950+
Argument(
1951+
"AE_use_NGA",
1952+
bool,
1953+
optional=True,
1954+
default=False,
1955+
),
19501956
]
19511957

19521958

0 commit comments

Comments
 (0)