|
17 | 17 | get_residual, |
18 | 18 | ) |
19 | 19 | from deepmd.pt.model.network.mlp import ( |
| 20 | + GatedMLP, |
20 | 21 | MLPLayer, |
21 | 22 | ) |
22 | 23 | from deepmd.pt.model.network.utils import ( |
@@ -59,6 +60,8 @@ def __init__( |
59 | 60 | sel_reduce_factor: float = 10.0, |
60 | 61 | smooth_edge_update: bool = False, |
61 | 62 | update_use_layernorm: bool = False, |
| 63 | + use_gated_mlp: bool = False, |
| 64 | + gated_mlp_norm: str = "none", |
62 | 65 | activation_function: str = "silu", |
63 | 66 | update_style: str = "res_residual", |
64 | 67 | update_residual: float = 0.1, |
@@ -99,6 +102,10 @@ def __init__( |
99 | 102 | self.update_residual = update_residual |
100 | 103 | self.update_residual_init = update_residual_init |
101 | 104 | self.update_use_layernorm = update_use_layernorm |
| 105 | + self.use_gated_mlp = use_gated_mlp |
| 106 | + if self.use_gated_mlp: |
| 107 | + assert not optim_update, "Gated MLP does not support optim update!" |
| 108 | + self.gated_mlp_norm = gated_mlp_norm |
102 | 109 | self.a_compress_e_rate = a_compress_e_rate |
103 | 110 | self.a_compress_use_split = a_compress_use_split |
104 | 111 | self.precision = precision |
@@ -165,13 +172,23 @@ def __init__( |
165 | 172 | ) |
166 | 173 |
|
167 | 174 | # node edge message |
168 | | - self.node_edge_linear = MLPLayer( |
169 | | - self.edge_info_dim, |
170 | | - self.n_multi_edge_message * n_dim, |
171 | | - precision=precision, |
172 | | - seed=child_seed(seed, 4), |
173 | | - trainable=trainable, |
174 | | - ) |
| 175 | + if not self.use_gated_mlp: |
| 176 | + self.node_edge_linear = MLPLayer( |
| 177 | + self.edge_info_dim, |
| 178 | + self.n_multi_edge_message * n_dim, |
| 179 | + precision=precision, |
| 180 | + seed=child_seed(seed, 4), |
| 181 | + trainable=trainable, |
| 182 | + ) |
| 183 | + else: |
| 184 | + self.node_edge_linear = GatedMLP( |
| 185 | + self.edge_info_dim, |
| 186 | + self.n_multi_edge_message * n_dim, |
| 187 | + activation_function=self.activation_function, |
| 188 | + norm=self.gated_mlp_norm, |
| 189 | + precision=precision, |
| 190 | + seed=child_seed(seed, 4), |
| 191 | + ) |
175 | 192 | if self.update_style == "res_residual": |
176 | 193 | for head_index in range(self.n_multi_edge_message): |
177 | 194 | self.n_residual.append( |
@@ -256,13 +273,23 @@ def __init__( |
256 | 273 | self.a_compress_e_linear = None |
257 | 274 |
|
258 | 275 | # edge angle message |
259 | | - self.edge_angle_linear1 = MLPLayer( |
260 | | - self.angle_dim, |
261 | | - self.e_dim, |
262 | | - precision=precision, |
263 | | - seed=child_seed(seed, 10), |
264 | | - trainable=trainable, |
265 | | - ) |
| 276 | + if not self.use_gated_mlp: |
| 277 | + self.edge_angle_linear1 = MLPLayer( |
| 278 | + self.angle_dim, |
| 279 | + self.e_dim, |
| 280 | + precision=precision, |
| 281 | + seed=child_seed(seed, 10), |
| 282 | + trainable=trainable, |
| 283 | + ) |
| 284 | + else: |
| 285 | + self.edge_angle_linear1 = GatedMLP( |
| 286 | + self.angle_dim, |
| 287 | + self.e_dim, |
| 288 | + activation_function=self.activation_function, |
| 289 | + norm=self.gated_mlp_norm, |
| 290 | + precision=precision, |
| 291 | + seed=child_seed(seed, 10), |
| 292 | + ) |
266 | 293 | self.edge_angle_linear2 = MLPLayer( |
267 | 294 | self.e_dim, |
268 | 295 | self.e_dim, |
|
0 commit comments