Skip to content

Commit 9b5e152

Browse files
committed
add NGA edge
1 parent 983e94b commit 9b5e152

5 files changed

Lines changed: 67 additions & 0 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
angle_sh_init_lmax: int = 3,
8989
angle_use_fixed_gaussian: bool = False,
9090
angle_fixed_gaussian_interpolate: bool = False,
91+
EN_use_NGA: bool = False,
9192
) -> None:
9293
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
9394
@@ -224,6 +225,7 @@ def __init__(
224225
self.angle_sh_init_lmax = angle_sh_init_lmax
225226
self.angle_use_fixed_gaussian = angle_use_fixed_gaussian
226227
self.angle_fixed_gaussian_interpolate = angle_fixed_gaussian_interpolate
228+
self.EN_use_NGA = EN_use_NGA
227229
assert (
228230
fix_stat_std == 0.3
229231
), "fix_stat_std is not implemented in this version, please use skip_stat instead."

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def init_subclass_params(sub_data, sub_class):
227227
angle_sh_init_lmax=self.repflow_args.angle_sh_init_lmax,
228228
angle_use_fixed_gaussian=self.repflow_args.angle_use_fixed_gaussian,
229229
angle_fixed_gaussian_interpolate=self.repflow_args.angle_fixed_gaussian_interpolate,
230+
EN_use_NGA=self.repflow_args.EN_use_NGA,
230231
exclude_types=exclude_types,
231232
env_protection=env_protection,
232233
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
edge_rbf_cat_message: bool = False,
9797
edge_message_use_dropout: bool = False,
9898
angle_message_use_dropout: bool = False,
99+
EN_use_NGA: bool = False,
99100
dropout_rate: float = 0.1,
100101
activation_function: str = "silu",
101102
update_style: str = "res_residual",
@@ -146,6 +147,9 @@ def __init__(
146147
self.sel_reduce_factor = sel_reduce_factor
147148
self.dynamic_e_sel = self.nnei / self.sel_reduce_factor
148149
self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor
150+
self.EN_use_NGA = EN_use_NGA
151+
if self.EN_use_NGA:
152+
assert not self.use_dynamic_sel, "NGA does not support dynamic selection!"
149153

150154
self.update_dihedral = update_dihedral
151155
self.d_dim = d_dim
@@ -375,6 +379,35 @@ def __init__(
375379
)
376380
residual_idx += 1
377381

382+
# node edge NGA
383+
if self.EN_use_NGA:
384+
self.edge_nga_mlp = MLPLayer(
385+
e_dim,
386+
n_dim,
387+
precision=precision,
388+
seed=child_seed(seed, 20),
389+
)
390+
self.node_nga_mlp = MLPLayer(
391+
2 * n_dim,
392+
n_dim,
393+
precision=precision,
394+
seed=child_seed(seed, 21),
395+
)
396+
if self.update_style == "res_residual":
397+
self.n_residual.append(
398+
get_residual(
399+
n_dim,
400+
self.update_residual * self.residual_pref[residual_idx],
401+
self.update_residual_init,
402+
precision=precision,
403+
seed=child_seed(seed, 22),
404+
)
405+
)
406+
residual_idx += 1
407+
else:
408+
self.edge_nga_mlp = None
409+
self.node_nga_mlp = None
410+
378411
# edge self message
379412
if not self.use_gated_mlp or self.only_angle_gated_mlp:
380413
self.edge_self_linear = MLPLayer(
@@ -1330,6 +1363,7 @@ def forward(
13301363
else:
13311364
edge_info = None
13321365
edge_info_ffn = None
1366+
edge_cat_list = None
13331367

13341368
# edge message use dropout
13351369
if self.edge_message_use_dropout:
@@ -1411,6 +1445,27 @@ def forward(
14111445
)
14121446
)
14131447
n_update_list.append(node_edge_update)
1448+
1449+
if self.EN_use_NGA:
1450+
assert self.node_nga_mlp is not None
1451+
assert self.edge_nga_mlp is not None
1452+
assert edge_cat_list is not None
1453+
# nb, nloc, nnei, n_dim
1454+
attention_weights_nga_i = self.edge_nga_mlp(edge_cat_list[2])
1455+
attention_weights_nga_i = (attention_weights_nga_i + 20.0) * sw.unsqueeze(
1456+
-1
1457+
) - 20.0
1458+
attention_weights_nga_i = torch.softmax(attention_weights_nga_i, dim=-2)
1459+
# nb, nloc, nnei, n_dim
1460+
attention_value_nga = self.node_nga_mlp(
1461+
torch.cat(edge_cat_list[:2], dim=-1)
1462+
)
1463+
1464+
# updated value
1465+
# nb, nloc, n_dim
1466+
update_node_nga = (attention_weights_nga_i * attention_value_nga).sum(-2)
1467+
n_update_list.append(update_node_nga)
1468+
14141469
# update node_ebd
14151470
n_updated = self.list_update(n_update_list, "node")
14161471

deepmd/pt/model/descriptor/repflows.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def __init__(
171171
angle_sh_init_lmax: int = 3,
172172
angle_use_fixed_gaussian: bool = False,
173173
angle_fixed_gaussian_interpolate: bool = False,
174+
EN_use_NGA: bool = False,
174175
seed: Optional[Union[int, list[int]]] = None,
175176
) -> None:
176177
r"""
@@ -316,6 +317,7 @@ def __init__(
316317
)
317318
else:
318319
self.angle_gaussian_encoder = None
320+
self.EN_use_NGA = EN_use_NGA
319321
self.use_env_envelope = use_env_envelope
320322
self.use_new_sw = use_new_sw
321323
self.use_force_embedding = use_force_embedding
@@ -582,6 +584,7 @@ def __init__(
582584
edge_rbf_cat_message=self.edge_rbf_cat_message,
583585
edge_message_use_dropout=self.edge_message_use_dropout,
584586
angle_message_use_dropout=self.angle_message_use_dropout,
587+
EN_use_NGA=self.EN_use_NGA,
585588
dropout_rate=self.dropout_rate,
586589
seed=child_seed(child_seed(seed, 1), ii),
587590
)

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,6 +1941,12 @@ def dpa3_repflow_args():
19411941
optional=True,
19421942
default=0.1,
19431943
),
1944+
Argument(
1945+
"EN_use_NGA",
1946+
bool,
1947+
optional=True,
1948+
default=False,
1949+
),
19441950
]
19451951

19461952

0 commit comments

Comments
 (0)