
Commit 6e11a36

feat(pt): add angle gated attention

1 parent 32267af

5 files changed: 112 additions & 0 deletions


deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 4 additions & 0 deletions
@@ -77,6 +77,8 @@ def __init__(
         rk_order: int = 4,
         rk_update_diff_layer: bool = False,
         angle_use_node: bool = True,
+        angle_self_attention: bool = False,
+        angle_self_attention_gate: str = "none",
     ) -> None:
         r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.

@@ -202,6 +204,8 @@ def __init__(
         self.rk_update_diff_layer = rk_update_diff_layer
         self.angle_use_node = angle_use_node
         self.only_angle_gated_mlp = only_angle_gated_mlp
+        self.angle_self_attention = angle_self_attention
+        self.angle_self_attention_gate = angle_self_attention_gate
         assert (
             fix_stat_std == 0.3
         ), "fix_stat_std is not implemented in this version, please use skip_stat instead."

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
@@ -214,6 +214,8 @@ def init_subclass_params(sub_data, sub_class):
             rk_update_diff_layer=self.repflow_args.rk_update_diff_layer,
             angle_use_node=self.repflow_args.angle_use_node,
             use_loc_mapping=use_loc_mapping,
+            angle_self_attention=self.repflow_args.angle_self_attention,
+            angle_self_attention_gate=self.repflow_args.angle_self_attention_gate,
             exclude_types=exclude_types,
             env_protection=env_protection,
             precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 78 additions & 0 deletions
@@ -90,6 +90,8 @@ def __init__(
         only_angle_gated_mlp: bool = False,
         node_use_rmsnorm: bool = False,
         angle_use_node: bool = True,
+        angle_self_attention: bool = False,
+        angle_self_attention_gate: str = "none",
         activation_function: str = "silu",
         update_style: str = "res_residual",
         update_residual: float = 0.1,
@@ -188,6 +190,8 @@ def __init__(
         self.node_rmsnorm = None

         self.angle_use_node = angle_use_node
+        self.angle_self_attention = angle_self_attention
+        self.angle_self_attention_gate = angle_self_attention_gate

         if self.edge_rbf_dot_self or self.edge_rbf_dot_message:
             self.rbf_mlp = MLPLayer(
@@ -501,6 +505,23 @@ def __init__(
                 )
             )

+        if self.angle_self_attention:
+            self.angle_attention_mlp_in = MLPLayer(
+                self.a_dim,
+                self.a_dim * 3,  # query, key, value
+                precision=precision,
+                seed=child_seed(seed, 21),
+            )
+            self.angle_attention_mlp_out = MLPLayer(
+                self.a_dim,
+                self.a_dim,
+                precision=precision,
+                seed=child_seed(seed, 22),
+            )
+        else:
+            self.angle_attention_mlp_in = None
+            self.angle_attention_mlp_out = None
+
         if self.update_dihedral:
             self.dihedral_dim = self.d_dim + 2 * self.a_dim
             # angle dihedral message
@@ -1581,6 +1602,63 @@ def forward(
                 )
             a_update_list.append(angle_self_update)

+        if self.angle_self_attention:
+            # add a self-attention mechanism for angle_ebd with shape [nb x nloc x a_nnei x a_nnei x a_dim], on the last two dimensions
+            assert self.angle_attention_mlp_in is not None
+            assert self.angle_attention_mlp_out is not None
+            # nb x nloc x a_nnei x a_nnei x (3 * a_dim)
+            attention_output = self.angle_attention_mlp_in(angle_ebd)
+            # nb x nloc x a_nnei x a_nnei x a_dim
+            query, key, value = torch.chunk(
+                attention_output, 3, dim=-1
+            )  # Split into query, key, value
+            # nb x nloc x a_nnei x a_nnei x a_nnei
+            attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (
+                query.size(-1) ** 0.5
+            )  # Scaled dot-product attention
+            # smooth
+            attention_scores = (attention_scores + 20.0) * a_sw[
+                :, :, None, :, None
+            ] * a_sw[:, :, None, None, :] - 20.0
+            # nb x nloc x a_nnei x a_nnei x a_nnei
+            attention_weights = torch.softmax(
+                attention_scores, dim=-1
+            )  # Normalize scores
+            # smooth
+            attention_weights = (
+                attention_weights
+                * a_sw[:, :, None, :, None]
+                * a_sw[:, :, None, None, :]
+            )
+            # optional gates
+            if self.angle_self_attention_gate == "edge":
+                # nb x nloc x a_nnei x 3
+                h2_angle = h2[..., : self.a_sel, :]
+                # normalize
+                h2_angle = h2_angle / torch.linalg.norm(
+                    h2_angle, dim=-1, keepdim=True
+                )
+                # nb x nloc x a_nnei x 3
+                h2_angle = torch.where(
+                    a_nlist_mask.unsqueeze(-1).expand([-1, -1, -1, 3]),
+                    h2_angle,
+                    0.0,
+                )
+                # nb x nloc x a_nnei x a_nnei
+                h2h2t = torch.matmul(h2_angle, torch.transpose(h2_angle, -1, -2))
+                # nb x nloc x a_nnei x a_nnei x a_nnei
+                attention_weights = attention_weights * h2h2t[:, :, None, :, :]
+
+            # nb x nloc x a_nnei x a_nnei x a_dim
+            angle_ebd_attended = torch.matmul(
+                attention_weights, value
+            )  # Apply attention weights to value
+            # nb x nloc x a_nnei x a_nnei x a_dim
+            angle_attention_updated = self.act(
+                self.angle_attention_mlp_out(angle_ebd_attended)
+            )  # Apply attention output layer
+            a_update_list.append(angle_attention_updated)
+
         # dihedral update with fixed sel
         if self.update_dihedral and not self.use_dynamic_sel:
             assert d_nlist is not None
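The notable part of this hunk is the smoothing of the attention scores by the angular switch function a_sw before and after the softmax, which makes the attention weights decay continuously to zero as a neighbor leaves the cutoff. Below is a minimal standalone sketch of that mechanism with toy shapes, using a plain torch.nn.Linear in place of MLPLayer; all names here are illustrative stand-ins, not the repflow code itself:

    import torch

    # toy dimensions mirroring the shape comments in the diff:
    # nb x nloc x a_nnei x a_nnei x a_dim
    nb, nloc, a_nnei, a_dim = 2, 5, 4, 8
    angle_ebd = torch.randn(nb, nloc, a_nnei, a_nnei, a_dim)
    # per-neighbor switch in [0, 1]; 0 means the neighbor sits outside the cutoff
    a_sw = torch.rand(nb, nloc, a_nnei)

    # stand-in for angle_attention_mlp_in: one projection producing q, k, v
    qkv = torch.nn.Linear(a_dim, 3 * a_dim)
    query, key, value = torch.chunk(qkv(angle_ebd), 3, dim=-1)

    # scaled dot-product scores over the last two neighbor axes:
    # nb x nloc x a_nnei x a_nnei x a_nnei
    scores = torch.matmul(query, key.transpose(-2, -1)) / a_dim**0.5

    # the smoothing trick: shift scores up by 20, scale by the switches of
    # both neighbors involved, shift back. A pair with a_sw -> 0 is pushed
    # toward -20, so its softmax weight vanishes smoothly instead of jumping
    # when a neighbor crosses the cutoff.
    sw_i = a_sw[:, :, None, :, None]
    sw_j = a_sw[:, :, None, None, :]
    scores = (scores + 20.0) * sw_i * sw_j - 20.0

    weights = torch.softmax(scores, dim=-1)
    # re-apply the switches so the weights themselves also decay to zero
    weights = weights * sw_i * sw_j

    # nb x nloc x a_nnei x a_nnei x a_dim
    attended = torch.matmul(weights, value)
    print(attended.shape)  # torch.Size([2, 5, 4, 4, 8])

The optional "edge" gate then multiplies these weights by h2h2t, the pairwise dot products of normalized edge direction vectors, i.e. the cosines of the angles between bond directions, so attention between two neighbors is additionally scaled by how aligned their edges are.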

deepmd/pt/model/descriptor/repflows.py

Lines changed: 16 additions & 0 deletions
@@ -158,6 +158,8 @@ def __init__(
         rk_update_diff_layer: bool = False,
         angle_use_node: bool = True,
         optim_update: bool = True,
+        angle_self_attention: bool = False,
+        angle_self_attention_gate: str = "none",
         seed: Optional[Union[int, list[int]]] = None,
     ) -> None:
         r"""
@@ -391,6 +393,18 @@ def __init__(
         else:
             self.env = None

+        self.angle_self_attention = angle_self_attention
+        self.angle_self_attention_gate = angle_self_attention_gate
+        if self.angle_self_attention:
+            assert (
+                not self.use_dynamic_sel
+            ), "angle_self_attention does not support dynamic selection so far"
+            assert self.angle_self_attention_gate in [
+                "none",
+                "edge",
+                "edge_feat",
+            ], "angle_self_attention_gate must be 'none', 'edge' or 'edge_feat'"
+
         self.activation_function = activation_function
         self.update_style = update_style
         self.update_residual = update_residual
@@ -501,6 +515,8 @@ def __init__(
                 only_angle_gated_mlp=self.only_angle_gated_mlp,
                 node_use_rmsnorm=self.node_use_rmsnorm,
                 angle_use_node=self.angle_use_node,
+                angle_self_attention=self.angle_self_attention,
+                angle_self_attention_gate=self.angle_self_attention_gate,
                 seed=child_seed(child_seed(seed, 1), ii),
             )
         )
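Note that the validator above accepts three gate modes, but only the "edge" gate has a corresponding branch in the forward pass added to repflow_layer.py above; an "edge_feat" branch is not part of this diff.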

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
@@ -1869,6 +1869,18 @@ def dpa3_repflow_args():
             optional=True,
             default=True,
         ),
+        Argument(
+            "angle_self_attention",
+            bool,
+            optional=True,
+            default=False,
+        ),
+        Argument(
+            "angle_self_attention_gate",
+            str,
+            optional=True,
+            default="none",
+        ),
     ]
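With the two arguments registered in argcheck, the feature can be switched on from the repflow block of a DPA3 descriptor config. A hypothetical fragment, written as a Python dict; only the two angle_self_attention keys are introduced by this commit, the rest is illustrative context:

    # hypothetical repflow config fragment
    repflow_args = {
        "angle_use_node": True,
        "angle_self_attention": True,         # enable the new attention update
        "angle_self_attention_gate": "edge",  # one of "none", "edge", "edge_feat"
    }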
