Skip to content

Commit 12ab9b9

Browse files
committed
add update_use_layernorm
1 parent 71ed3da commit 12ab9b9

5 files changed

Lines changed: 33 additions & 0 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def __init__(
180180
use_exp_switch: bool = False,
181181
use_dynamic_sel: bool = False,
182182
sel_reduce_factor: float = 10.0,
183+
update_use_layernorm: bool = False,
183184
) -> None:
184185
self.n_dim = n_dim
185186
self.e_dim = e_dim
@@ -210,6 +211,7 @@ def __init__(
210211
self.use_exp_switch = use_exp_switch
211212
self.use_dynamic_sel = use_dynamic_sel
212213
self.sel_reduce_factor = sel_reduce_factor
214+
self.update_use_layernorm = update_use_layernorm
213215

214216
def __getitem__(self, key: str) -> Any:
215217
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any:
166166
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
167167
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
168168
use_loc_mapping=use_loc_mapping,
169+
update_use_layernorm=self.repflow_args.update_use_layernorm,
169170
exclude_types=exclude_types,
170171
env_protection=env_protection,
171172
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
use_dynamic_sel: bool = False,
5959
sel_reduce_factor: float = 10.0,
6060
smooth_edge_update: bool = False,
61+
update_use_layernorm: bool = False,
6162
activation_function: str = "silu",
6263
update_style: str = "res_residual",
6364
update_residual: float = 0.1,
@@ -97,6 +98,7 @@ def __init__(
9798
self.update_style = update_style
9899
self.update_residual = update_residual
99100
self.update_residual_init = update_residual_init
101+
self.update_use_layernorm = update_use_layernorm
100102
self.a_compress_e_rate = a_compress_e_rate
101103
self.a_compress_use_split = a_compress_use_split
102104
self.precision = precision
@@ -203,6 +205,17 @@ def __init__(
203205
)
204206
)
205207

208+
if self.update_use_layernorm:
209+
self.node_layernorm = torch.nn.LayerNorm(self.n_dim)
210+
self.edge_layernorm = torch.nn.LayerNorm(self.e_dim)
211+
self.angle_layernorm = (
212+
torch.nn.LayerNorm(self.a_dim) if self.update_angle else None
213+
)
214+
else:
215+
self.node_layernorm = None
216+
self.edge_layernorm = None
217+
self.angle_layernorm = None
218+
206219
if self.update_angle:
207220
self.angle_dim = self.a_dim
208221
if self.a_compress_rate == 0:
@@ -1133,6 +1146,14 @@ def forward(
11331146

11341147
# update angle_ebd
11351148
a_updated = self.list_update(a_update_list, "angle")
1149+
if self.update_use_layernorm:
1150+
assert self.node_layernorm is not None
1151+
n_updated = self.node_layernorm(n_updated)
1152+
assert self.edge_layernorm is not None
1153+
e_updated = self.edge_layernorm(e_updated)
1154+
if self.update_angle:
1155+
assert self.angle_layernorm is not None
1156+
a_updated = self.angle_layernorm(a_updated)
11361157
return n_updated, e_updated, a_updated
11371158

11381159
@torch.jit.export

deepmd/pt/model/descriptor/repflows.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def __init__(
220220
use_dynamic_sel: bool = False,
221221
sel_reduce_factor: float = 10.0,
222222
use_loc_mapping: bool = True,
223+
update_use_layernorm: bool = False,
223224
optim_update: bool = True,
224225
seed: Optional[Union[int, list[int]]] = None,
225226
trainable: bool = True,
@@ -285,6 +286,7 @@ def __init__(
285286
self.precision = precision
286287
self.epsilon = 1e-4
287288
self.seed = seed
289+
self.update_use_layernorm = update_use_layernorm
288290

289291
self.edge_embd = MLPLayer(
290292
1,
@@ -330,6 +332,7 @@ def __init__(
330332
use_dynamic_sel=self.use_dynamic_sel,
331333
sel_reduce_factor=self.sel_reduce_factor,
332334
smooth_edge_update=self.smooth_edge_update,
335+
update_use_layernorm=self.update_use_layernorm,
333336
seed=child_seed(child_seed(seed, 1), ii),
334337
trainable=trainable,
335338
)

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,6 +1667,12 @@ def dpa3_repflow_args() -> list[Argument]:
16671667
default=10.0,
16681668
doc=doc_sel_reduce_factor,
16691669
),
1670+
Argument(
1671+
"update_use_layernorm",
1672+
bool,
1673+
optional=True,
1674+
default=False,
1675+
),
16701676
]
16711677

16721678

0 commit comments

Comments
 (0)