Skip to content

Commit 6ba2399

Browse files
committed
feat(pt): add rmsnorm for edge components
1 parent 6e11a36 commit 6ba2399

5 files changed

Lines changed: 72 additions & 4 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
angle_use_node: bool = True,
8080
angle_self_attention: bool = False,
8181
angle_self_attention_gate: str = "none",
82+
rmsnorm_mode: str = "none",
8283
) -> None:
8384
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
8485
@@ -206,6 +207,7 @@ def __init__(
206207
self.only_angle_gated_mlp = only_angle_gated_mlp
207208
self.angle_self_attention = angle_self_attention
208209
self.angle_self_attention_gate = angle_self_attention_gate
210+
self.rmsnorm_mode = rmsnorm_mode
209211
assert (
210212
fix_stat_std == 0.3
211213
), "fix_stat_std is not implemented in this version, please use skip_stat instead."

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def init_subclass_params(sub_data, sub_class):
216216
use_loc_mapping=use_loc_mapping,
217217
angle_self_attention=self.repflow_args.angle_self_attention,
218218
angle_self_attention_gate=self.repflow_args.angle_self_attention_gate,
219+
rmsnorm_mode=self.repflow_args.rmsnorm_mode,
219220
exclude_types=exclude_types,
220221
env_protection=env_protection,
221222
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(
9292
angle_use_node: bool = True,
9393
angle_self_attention: bool = False,
9494
angle_self_attention_gate: str = "none",
95+
rmsnorm_mode: str = "none",
9596
activation_function: str = "silu",
9697
update_style: str = "res_residual",
9798
update_residual: float = 0.1,
@@ -189,6 +190,40 @@ def __init__(
189190
else:
190191
self.node_rmsnorm = None
191192

193+
# add rms norm debug for each component, can be removed if not necessary
194+
self.rmsnorm_mode = rmsnorm_mode
195+
self.rmsnorm_mod_list = self.rmsnorm_mode.split(":")
196+
# mode: ['NEM', 'ESM', 'EAM', 'ASM', 'E']
197+
# node edge message
198+
if "NEM" in self.rmsnorm_mod_list:
199+
self.NEM_rmsnorm = RMSNorm(self.n_dim, precision=precision, trainable=True)
200+
else:
201+
self.NEM_rmsnorm = None
202+
203+
# edge self message
204+
if "ESM" in self.rmsnorm_mod_list:
205+
self.ESM_rmsnorm = RMSNorm(self.e_dim, precision=precision, trainable=True)
206+
else:
207+
self.ESM_rmsnorm = None
208+
209+
# edge angle message
210+
if "EAM" in self.rmsnorm_mod_list:
211+
self.EAM_rmsnorm = RMSNorm(self.e_dim, precision=precision, trainable=True)
212+
else:
213+
self.EAM_rmsnorm = None
214+
215+
# angle self message
216+
if "ASM" in self.rmsnorm_mod_list:
217+
self.ASM_rmsnorm = RMSNorm(self.a_dim, precision=precision, trainable=True)
218+
else:
219+
self.ASM_rmsnorm = None
220+
221+
# edge self
222+
if "E" in self.rmsnorm_mod_list:
223+
self.edge_rmsnorm = RMSNorm(self.e_dim, precision=precision, trainable=True)
224+
else:
225+
self.edge_rmsnorm = None
226+
192227
self.angle_use_node = angle_use_node
193228
self.angle_self_attention = angle_self_attention
194229
self.angle_self_attention_gate = angle_self_attention_gate
@@ -1320,6 +1355,10 @@ def forward(
13201355
)
13211356
)
13221357

1358+
if "NEM" in self.rmsnorm_mod_list:
1359+
assert self.NEM_rmsnorm is not None
1360+
node_edge_update = self.NEM_rmsnorm(node_edge_update)
1361+
13231362
if self.n_multi_edge_message > 1:
13241363
# nb x nloc x h x n_dim
13251364
node_edge_update_mul_head = node_edge_update.view(
@@ -1372,6 +1411,11 @@ def forward(
13721411
if self.edge_rbf_dot_message:
13731412
assert edge_rbf is not None
13741413
edge_self_update = edge_self_update * edge_rbf
1414+
1415+
if "ESM" in self.rmsnorm_mod_list:
1416+
assert self.ESM_rmsnorm is not None
1417+
edge_self_update = self.ESM_rmsnorm(edge_self_update)
1418+
13751419
e_update_list.append(edge_self_update)
13761420

13771421
# edge attention message
@@ -1561,11 +1605,15 @@ def forward(
15611605
)
15621606
if not self.use_slim_message:
15631607
assert self.edge_angle_linear2 is not None
1564-
e_update_list.append(
1565-
self.act(self.edge_angle_linear2(padding_edge_angle_update))
1608+
padding_edge_angle_update = self.act(
1609+
self.edge_angle_linear2(padding_edge_angle_update)
15661610
)
1567-
else:
1568-
e_update_list.append(padding_edge_angle_update)
1611+
1612+
if "EAM" in self.rmsnorm_mod_list:
1613+
assert self.EAM_rmsnorm is not None
1614+
padding_edge_angle_update = self.EAM_rmsnorm(padding_edge_angle_update)
1615+
1616+
e_update_list.append(padding_edge_angle_update)
15691617
# update edge_ebd
15701618
e_updated = self.list_update(e_update_list, "edge")
15711619

@@ -1600,6 +1648,10 @@ def forward(
16001648
"angle",
16011649
)
16021650
)
1651+
if "ASM" in self.rmsnorm_mod_list:
1652+
assert self.ASM_rmsnorm is not None
1653+
angle_self_update = self.ASM_rmsnorm(angle_self_update)
1654+
16031655
a_update_list.append(angle_self_update)
16041656

16051657
if self.angle_self_attention:
@@ -1828,6 +1880,10 @@ def list_update_res_residual(
18281880
if update_name == "node" and self.node_use_rmsnorm:
18291881
assert self.node_rmsnorm is not None
18301882
uu = self.node_rmsnorm(uu)
1883+
1884+
if update_name == "edge" and "E" in self.rmsnorm_mod_list:
1885+
assert self.edge_rmsnorm is not None
1886+
uu = self.edge_rmsnorm(uu)
18311887
return uu
18321888

18331889
@torch.jit.export

deepmd/pt/model/descriptor/repflows.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __init__(
160160
optim_update: bool = True,
161161
angle_self_attention: bool = False,
162162
angle_self_attention_gate: str = "none",
163+
rmsnorm_mode: str = "none",
163164
seed: Optional[Union[int, list[int]]] = None,
164165
) -> None:
165166
r"""
@@ -404,6 +405,7 @@ def __init__(
404405
"edge",
405406
"edge_feat",
406407
], "angle_self_attention_gate must be 'none', 'edge' or 'edge_feat'"
408+
self.rmsnorm_mode = rmsnorm_mode
407409

408410
self.activation_function = activation_function
409411
self.update_style = update_style
@@ -517,6 +519,7 @@ def __init__(
517519
angle_use_node=self.angle_use_node,
518520
angle_self_attention=self.angle_self_attention,
519521
angle_self_attention_gate=self.angle_self_attention_gate,
522+
rmsnorm_mode=self.rmsnorm_mode,
520523
seed=child_seed(child_seed(seed, 1), ii),
521524
)
522525
)

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1881,6 +1881,12 @@ def dpa3_repflow_args():
18811881
optional=True,
18821882
default="none",
18831883
),
1884+
Argument(
1885+
"rmsnorm_mode",
1886+
str,
1887+
optional=True,
1888+
default="none",
1889+
),
18841890
]
18851891

18861892

0 commit comments

Comments
 (0)