Skip to content

Commit 28d688c

Browse files
committed
add use_gated_mlp and gated_mlp_norm
1 parent 12ab9b9 commit 28d688c

6 files changed

Lines changed: 149 additions & 14 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ def __init__(
181181
use_dynamic_sel: bool = False,
182182
sel_reduce_factor: float = 10.0,
183183
update_use_layernorm: bool = False,
184+
use_gated_mlp: bool = False,
185+
gated_mlp_norm: str = "none",
184186
) -> None:
185187
self.n_dim = n_dim
186188
self.e_dim = e_dim
@@ -212,6 +214,8 @@ def __init__(
212214
self.use_dynamic_sel = use_dynamic_sel
213215
self.sel_reduce_factor = sel_reduce_factor
214216
self.update_use_layernorm = update_use_layernorm
217+
self.use_gated_mlp = use_gated_mlp
218+
self.gated_mlp_norm = gated_mlp_norm
215219

216220
def __getitem__(self, key: str) -> Any:
217221
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,10 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any:
166166
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
167167
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
168168
use_loc_mapping=use_loc_mapping,
169+
# the following are newly added parameters
169170
update_use_layernorm=self.repflow_args.update_use_layernorm,
171+
use_gated_mlp=self.repflow_args.use_gated_mlp,
172+
gated_mlp_norm=self.repflow_args.gated_mlp_norm,
170173
exclude_types=exclude_types,
171174
env_protection=env_protection,
172175
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_residual,
1818
)
1919
from deepmd.pt.model.network.mlp import (
20+
GatedMLP,
2021
MLPLayer,
2122
)
2223
from deepmd.pt.model.network.utils import (
@@ -59,6 +60,8 @@ def __init__(
5960
sel_reduce_factor: float = 10.0,
6061
smooth_edge_update: bool = False,
6162
update_use_layernorm: bool = False,
63+
use_gated_mlp: bool = False,
64+
gated_mlp_norm: str = "none",
6265
activation_function: str = "silu",
6366
update_style: str = "res_residual",
6467
update_residual: float = 0.1,
@@ -99,6 +102,10 @@ def __init__(
99102
self.update_residual = update_residual
100103
self.update_residual_init = update_residual_init
101104
self.update_use_layernorm = update_use_layernorm
105+
self.use_gated_mlp = use_gated_mlp
106+
if self.use_gated_mlp:
107+
assert not optim_update, "Gated MLP does not support optim update!"
108+
self.gated_mlp_norm = gated_mlp_norm
102109
self.a_compress_e_rate = a_compress_e_rate
103110
self.a_compress_use_split = a_compress_use_split
104111
self.precision = precision
@@ -165,13 +172,23 @@ def __init__(
165172
)
166173

167174
# node edge message
168-
self.node_edge_linear = MLPLayer(
169-
self.edge_info_dim,
170-
self.n_multi_edge_message * n_dim,
171-
precision=precision,
172-
seed=child_seed(seed, 4),
173-
trainable=trainable,
174-
)
175+
if not self.use_gated_mlp:
176+
self.node_edge_linear = MLPLayer(
177+
self.edge_info_dim,
178+
self.n_multi_edge_message * n_dim,
179+
precision=precision,
180+
seed=child_seed(seed, 4),
181+
trainable=trainable,
182+
)
183+
else:
184+
self.node_edge_linear = GatedMLP(
185+
self.edge_info_dim,
186+
self.n_multi_edge_message * n_dim,
187+
activation_function=self.activation_function,
188+
norm=self.gated_mlp_norm,
189+
precision=precision,
190+
seed=child_seed(seed, 4),
191+
)
175192
if self.update_style == "res_residual":
176193
for head_index in range(self.n_multi_edge_message):
177194
self.n_residual.append(
@@ -256,13 +273,23 @@ def __init__(
256273
self.a_compress_e_linear = None
257274

258275
# edge angle message
259-
self.edge_angle_linear1 = MLPLayer(
260-
self.angle_dim,
261-
self.e_dim,
262-
precision=precision,
263-
seed=child_seed(seed, 10),
264-
trainable=trainable,
265-
)
276+
if not self.use_gated_mlp:
277+
self.edge_angle_linear1 = MLPLayer(
278+
self.angle_dim,
279+
self.e_dim,
280+
precision=precision,
281+
seed=child_seed(seed, 10),
282+
trainable=trainable,
283+
)
284+
else:
285+
self.edge_angle_linear1 = GatedMLP(
286+
self.angle_dim,
287+
self.e_dim,
288+
activation_function=self.activation_function,
289+
norm=self.gated_mlp_norm,
290+
precision=precision,
291+
seed=child_seed(seed, 10),
292+
)
266293
self.edge_angle_linear2 = MLPLayer(
267294
self.e_dim,
268295
self.e_dim,

deepmd/pt/model/descriptor/repflows.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ def __init__(
221221
sel_reduce_factor: float = 10.0,
222222
use_loc_mapping: bool = True,
223223
update_use_layernorm: bool = False,
224+
use_gated_mlp: bool = False,
225+
gated_mlp_norm: str = "none",
224226
optim_update: bool = True,
225227
seed: Optional[Union[int, list[int]]] = None,
226228
trainable: bool = True,
@@ -287,6 +289,8 @@ def __init__(
287289
self.epsilon = 1e-4
288290
self.seed = seed
289291
self.update_use_layernorm = update_use_layernorm
292+
self.use_gated_mlp = use_gated_mlp
293+
self.gated_mlp_norm = gated_mlp_norm
290294

291295
self.edge_embd = MLPLayer(
292296
1,
@@ -333,6 +337,8 @@ def __init__(
333337
sel_reduce_factor=self.sel_reduce_factor,
334338
smooth_edge_update=self.smooth_edge_update,
335339
update_use_layernorm=self.update_use_layernorm,
340+
use_gated_mlp=self.use_gated_mlp,
341+
gated_mlp_norm=self.gated_mlp_norm,
336342
seed=child_seed(child_seed(seed, 1), ii),
337343
trainable=trainable,
338344
)

deepmd/pt/model/network/mlp.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,89 @@ def check_load_param(ss: str) -> Optional[nn.Parameter]:
280280
return obj
281281

282282

283+
class GatedMLP(nn.Module):
    """Gated MLP block: an elementwise product of a core branch and a
    sigmoid gate branch, each a single linear layer with optional
    normalization. A similar structure is used in CGCNN and M3GNet.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        *,
        activation_function: Optional[str] = None,
        norm: str = "batch",
        bias: bool = True,
        precision: str = DEFAULT_PRECISION,
        seed: Optional[Union[int, list[int]]] = None,
    ) -> None:
        """Initialize a gated MLP.

        Args:
            input_dim (int): the input dimension
            output_dim (int): the output dimension
            activation_function (str, optional): The name of the activation
                function applied to the core branch (resolved by
                ``ActivationFn``). Default = None
            norm (str, optional): The name of the normalization layer applied
                to each branch's linear output. "batch" or "layer" enable
                normalization; None, "none", or an unrecognized name disable
                it (see ``find_normalization``). Default = "batch"
            bias (bool): whether to use bias in each linear layer.
                Default = True
            precision (str): numerical precision forwarded to the linear
                layers. Default = DEFAULT_PRECISION
            seed (int or list[int], optional): random seed forwarded to the
                linear layers. Default = None
        """
        super().__init__()
        # Core branch: linear projection, later passed through the activation.
        self.mlp_core = MLPLayer(
            input_dim,
            output_dim,
            bias=bias,
            precision=precision,
            seed=seed,
        )
        # Gate branch: linear projection, later squashed by a sigmoid.
        # NOTE(review): both branches receive the SAME seed, so their initial
        # weights are identical — confirm this is intended (a child seed per
        # branch would decorrelate them).
        self.mlp_gate = MLPLayer(
            input_dim,
            output_dim,
            bias=bias,
            precision=precision,
            seed=seed,
        )
        # for jit: expose the core branch's parameters under the attribute
        # names TorchScript expects on an MLP-like module
        self.matrix = self.mlp_core.matrix
        self.bias = self.mlp_core.bias
        self.act = ActivationFn(activation_function)
        self.sigmoid = nn.Sigmoid()
        # Separate normalization instances per branch so each learns its own
        # affine parameters / running statistics.
        self.norm1 = find_normalization(name=norm, dim=output_dim)
        self.norm2 = find_normalization(name=norm, dim=output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``act(norm1(core(x))) * sigmoid(norm2(gate(x)))``.

        Args:
            x (Tensor): a tensor of shape (batch_size, input_dim)

        Returns
        -------
        Tensor: a tensor of shape (batch_size, output_dim)
        """
        # norm1 and norm2 are created together from the same `norm` name, so
        # checking norm1 alone is sufficient to branch on.
        if self.norm1 is None:
            core = self.act(self.mlp_core(x))
            gate = self.sigmoid(self.mlp_gate(x))
        else:
            core = self.act(self.norm1(self.mlp_core(x)))
            gate = self.sigmoid(self.norm2(self.mlp_gate(x)))
        return core * gate
354+
355+
def find_normalization(name: Optional[str], dim: Optional[int] = None) -> Optional[nn.Module]:
    """Return the normalization module selected by ``name``.

    Parameters
    ----------
    name : str, optional
        Normalization kind, matched case-insensitively: "batch" selects
        ``nn.BatchNorm1d``, "layer" selects ``nn.LayerNorm``. ``None``,
        "none", or any unrecognized name return ``None`` (no normalization).
    dim : int, optional
        Feature dimension passed to the normalization constructor. Required
        when ``name`` selects "batch" or "layer".

    Returns
    -------
    nn.Module or None
        The constructed normalization layer, or ``None`` when disabled.
    """
    if name is None:
        return None
    # Construct lazily: the previous dict-literal form eagerly built BOTH
    # BatchNorm1d(dim) and LayerNorm(dim) on every call — wasted work, and a
    # crash when dim is None even though the caller asked for "none".
    key = name.lower()
    if key == "batch":
        return nn.BatchNorm1d(dim)
    if key == "layer":
        return nn.LayerNorm(dim)
    return None
365+
283366
MLP_ = make_multilayer_network(MLPLayer, nn.Module)
284367

285368

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,6 +1673,18 @@ def dpa3_repflow_args() -> list[Argument]:
16731673
optional=True,
16741674
default=False,
16751675
),
1676+
Argument(
1677+
"use_gated_mlp",
1678+
bool,
1679+
optional=True,
1680+
default=False,
1681+
),
1682+
Argument(
1683+
"gated_mlp_norm",
1684+
str,
1685+
optional=True,
1686+
default="none",
1687+
),
16761688
]
16771689

16781690

0 commit comments

Comments
 (0)