
Commit 7c2287e

add GatedMLP

1 parent 82286fd

6 files changed: 149 additions & 15 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py
deepmd/pt/model/descriptor/dpa3.py
deepmd/pt/model/descriptor/repflow_layer.py
deepmd/pt/model/descriptor/repflows.py
deepmd/pt/model/network/mlp.py
deepmd/utils/argcheck.py

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -178,6 +178,8 @@ def __init__(
         use_dynamic_sel: bool = False,
         sel_reduce_factor: float = 10.0,
         update_use_layernorm: bool = False,
+        use_gated_mlp: bool = False,
+        gated_mlp_norm: str = "none",
     ) -> None:
         self.n_dim = n_dim
         self.e_dim = e_dim
@@ -209,6 +211,8 @@ def __init__(
         self.use_dynamic_sel = use_dynamic_sel
         self.sel_reduce_factor = sel_reduce_factor
         self.update_use_layernorm = update_use_layernorm
+        self.use_gated_mlp = use_gated_mlp
+        self.gated_mlp_norm = gated_mlp_norm
 
     def __getitem__(self, key):
         if hasattr(self, key):
```

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -167,7 +167,10 @@ def init_subclass_params(sub_data, sub_class):
             use_dynamic_sel=self.repflow_args.use_dynamic_sel,
             sel_reduce_factor=self.repflow_args.sel_reduce_factor,
             use_loc_mapping=use_loc_mapping,
+            # the following are newly added parameters
             update_use_layernorm=self.repflow_args.update_use_layernorm,
+            use_gated_mlp=self.repflow_args.use_gated_mlp,
+            gated_mlp_norm=self.repflow_args.gated_mlp_norm,
             exclude_types=exclude_types,
             env_protection=env_protection,
             precision=precision,
```

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 39 additions & 12 deletions
```diff
@@ -17,6 +17,7 @@
     get_residual,
 )
 from deepmd.pt.model.network.mlp import (
+    GatedMLP,
     MLPLayer,
 )
 from deepmd.pt.model.network.utils import (
@@ -59,6 +60,8 @@ def __init__(
         sel_reduce_factor: float = 10.0,
         smooth_edge_update: bool = False,
         update_use_layernorm: bool = False,
+        use_gated_mlp: bool = False,
+        gated_mlp_norm: str = "none",
         activation_function: str = "silu",
         update_style: str = "res_residual",
         update_residual: float = 0.1,
@@ -98,6 +101,10 @@ def __init__(
         self.update_residual = update_residual
         self.update_residual_init = update_residual_init
         self.update_use_layernorm = update_use_layernorm
+        self.use_gated_mlp = use_gated_mlp
+        if self.use_gated_mlp:
+            assert not optim_update, "Gated MLP does not support optim update!"
+        self.gated_mlp_norm = gated_mlp_norm
         self.a_compress_e_rate = a_compress_e_rate
         self.a_compress_use_split = a_compress_use_split
         self.precision = precision
@@ -160,12 +167,22 @@ def __init__(
         )
 
         # node edge message
-        self.node_edge_linear = MLPLayer(
-            self.edge_info_dim,
-            self.n_multi_edge_message * n_dim,
-            precision=precision,
-            seed=child_seed(seed, 4),
-        )
+        if not self.use_gated_mlp:
+            self.node_edge_linear = MLPLayer(
+                self.edge_info_dim,
+                self.n_multi_edge_message * n_dim,
+                precision=precision,
+                seed=child_seed(seed, 4),
+            )
+        else:
+            self.node_edge_linear = GatedMLP(
+                self.edge_info_dim,
+                self.n_multi_edge_message * n_dim,
+                activation_function=self.activation_function,
+                norm=self.gated_mlp_norm,
+                precision=precision,
+                seed=child_seed(seed, 4),
+            )
         if self.update_style == "res_residual":
             for head_index in range(self.n_multi_edge_message):
                 self.n_residual.append(
@@ -245,12 +262,22 @@ def __init__(
             self.a_compress_e_linear = None
 
         # edge angle message
-        self.edge_angle_linear1 = MLPLayer(
-            self.angle_dim,
-            self.e_dim,
-            precision=precision,
-            seed=child_seed(seed, 10),
-        )
+        if not self.use_gated_mlp:
+            self.edge_angle_linear1 = MLPLayer(
+                self.angle_dim,
+                self.e_dim,
+                precision=precision,
+                seed=child_seed(seed, 10),
+            )
+        else:
+            self.edge_angle_linear1 = GatedMLP(
+                self.angle_dim,
+                self.e_dim,
+                activation_function=self.activation_function,
+                norm=self.gated_mlp_norm,
+                precision=precision,
+                seed=child_seed(seed, 10),
+            )
         self.edge_angle_linear2 = MLPLayer(
             self.e_dim,
             self.e_dim,
```
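
When `use_gated_mlp` is enabled, the node-edge and edge-angle message projections above are built as `GatedMLP` instead of a plain `MLPLayer`. A minimal, self-contained sketch of the gating computation in plain PyTorch (`TinyGatedMLP` and its dimensions are illustrative, not the deepmd classes):

```python
import torch
import torch.nn as nn


class TinyGatedMLP(nn.Module):
    """Illustrative gated MLP: out = act(core(x)) * sigmoid(gate(x))."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.core = nn.Linear(input_dim, output_dim)  # value branch
        self.gate = nn.Linear(input_dim, output_dim)  # gating branch
        self.act = nn.SiLU()  # matches the repflow default activation "silu"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the sigmoid gate scales each output channel into [0, 1]
        return self.act(self.core(x)) * torch.sigmoid(self.gate(x))


x = torch.randn(4, 8)
print(TinyGatedMLP(8, 16)(x).shape)  # torch.Size([4, 16])
```

Relative to a single linear projection, the gate lets the network suppress individual message channels per sample, at the cost of a second weight matrix.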

deepmd/pt/model/descriptor/repflows.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -218,6 +218,8 @@ def __init__(
         sel_reduce_factor: float = 10.0,
         use_loc_mapping: bool = True,
         update_use_layernorm: bool = False,
+        use_gated_mlp: bool = False,
+        gated_mlp_norm: str = "none",
         optim_update: bool = True,
         seed: Optional[Union[int, list[int]]] = None,
     ) -> None:
@@ -285,6 +287,8 @@ def __init__(
         self.epsilon = 1e-4
         self.seed = seed
         self.update_use_layernorm = update_use_layernorm
+        self.use_gated_mlp = use_gated_mlp
+        self.gated_mlp_norm = gated_mlp_norm
 
         self.edge_embd = MLPLayer(
             1, self.e_dim, precision=precision, seed=child_seed(seed, 0)
@@ -322,6 +326,8 @@ def __init__(
                     sel_reduce_factor=self.sel_reduce_factor,
                     smooth_edge_update=self.smooth_edge_update,
                     update_use_layernorm=self.update_use_layernorm,
+                    use_gated_mlp=self.use_gated_mlp,
+                    gated_mlp_norm=self.gated_mlp_norm,
                     seed=child_seed(child_seed(seed, 1), ii),
                 )
             )
@@ -336,7 +342,7 @@ def __init__(
         self.register_buffer("mean", mean)
         self.register_buffer("stddev", stddev)
         self.stats = None
-
+
         additional_output_for_fitting: dict[str, Optional[torch.Tensor]]
 
     def get_rcut(self) -> float:
@@ -370,7 +376,7 @@ def get_dim_in(self) -> int:
     def get_dim_emb(self) -> int:
         """Returns the embedding dimension e_dim."""
         return self.e_dim
-
+
     def get_additional_output_for_fitting(self):
         return self.additional_output_for_fitting
 
@@ -381,7 +387,6 @@ def get_norm_fact(self) -> list[float]:
             # float(self.dynamic_a_sel if self.use_dynamic_sel else self.a_sel),
         ]
 
-
     def __setitem__(self, key, value) -> None:
         if key in ("avg", "data_avg", "davg"):
             self.mean = value
```

deepmd/pt/model/network/mlp.py

Lines changed: 83 additions & 0 deletions
```diff
@@ -275,6 +275,89 @@ def check_load_param(ss):
         return obj
 
 
+class GatedMLP(nn.Module):
+    """Gated MLP.
+
+    A similar model structure is used in CGCNN and M3GNet.
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        *,
+        activation_function: Optional[str] = None,
+        norm: str = "batch",
+        bias: bool = True,
+        precision: str = DEFAULT_PRECISION,
+        seed: Optional[Union[int, list[int]]] = None,
+    ) -> None:
+        """Initialize a gated MLP.
+
+        Args:
+            input_dim (int): the input dimension
+            output_dim (int): the output dimension
+            activation_function (str, optional): the name of the activation function
+                applied to the core branch. Must be one of "relu", "silu", "tanh",
+                or "gelu". Default = None
+            norm (str, optional): the name of the normalization layer applied to the
+                updated features. Must be one of "batch", "layer", or "none".
+                Default = "batch"
+            bias (bool): whether to use a bias in each linear layer.
+                Default = True
+            precision (str): the precision of the linear layers.
+            seed (int or list[int], optional): the random seed for initialization.
+        """
+        super().__init__()
+        self.mlp_core = MLPLayer(
+            input_dim,
+            output_dim,
+            bias=bias,
+            precision=precision,
+            seed=seed,
+        )
+        self.mlp_gate = MLPLayer(
+            input_dim,
+            output_dim,
+            bias=bias,
+            precision=precision,
+            seed=seed,
+        )
+        # expose the core layer's parameters for TorchScript (jit)
+        self.matrix = self.mlp_core.matrix
+        self.bias = self.mlp_core.bias
+        self.act = ActivationFn(activation_function)
+        self.sigmoid = nn.Sigmoid()
+        self.norm1 = find_normalization(name=norm, dim=output_dim)
+        self.norm2 = find_normalization(name=norm, dim=output_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Perform a forward pass through the gated MLP.
+
+        Args:
+            x (Tensor): a tensor of shape (batch_size, input_dim)
+
+        Returns:
+            Tensor: a tensor of shape (batch_size, output_dim)
+        """
+        if self.norm1 is None:
+            core = self.act(self.mlp_core(x))
+            gate = self.sigmoid(self.mlp_gate(x))
+        else:
+            core = self.act(self.norm1(self.mlp_core(x)))
+            gate = self.sigmoid(self.norm2(self.mlp_gate(x)))
+        return core * gate
+
+
+def find_normalization(name: Optional[str], dim: Optional[int] = None) -> Optional[nn.Module]:
+    """Return the normalization layer selected by name, or None."""
+    if name is None or name.lower() == "none":
+        return None
+    if name.lower() == "batch":
+        return nn.BatchNorm1d(dim)
+    if name.lower() == "layer":
+        return nn.LayerNorm(dim)
+    return None
+
+
 MLP_ = make_multilayer_network(MLPLayer, nn.Module)
 
 
```
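A usage sketch for the new class, assuming it is importable exactly as committed (the dimensions and argument values are illustrative):

```python
import torch

from deepmd.pt.model.network.mlp import GatedMLP

layer = GatedMLP(16, 32, activation_function="silu", norm="layer")
x = torch.randn(4, 16)
out = layer(x)  # act(norm1(core(x))) * sigmoid(norm2(gate(x)))
print(out.shape)  # torch.Size([4, 32])
```

Note that both branches are constructed with the same `seed`, so `mlp_core` and `mlp_gate` start from identical initial weights.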

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -1677,6 +1677,18 @@ def dpa3_repflow_args():
             optional=True,
             default=False,
         ),
+        Argument(
+            "use_gated_mlp",
+            bool,
+            optional=True,
+            default=False,
+        ),
+        Argument(
+            "gated_mlp_norm",
+            str,
+            optional=True,
+            default="none",
+        ),
     ]
```
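Once registered, the two options can be set in the `repflow` block of a DPA3 descriptor input; both are optional and default to `false` and `"none"`. An illustrative fragment (surrounding keys abbreviated, not a complete input file):

```json
"descriptor": {
    "type": "dpa3",
    "repflow": {
        "use_gated_mlp": true,
        "gated_mlp_norm": "layer"
    }
}
```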
