|
17 | 17 | get_residual, |
18 | 18 | ) |
19 | 19 | from deepmd.pt.model.network.mlp import ( |
| 20 | + GatedMLP, |
20 | 21 | MLPLayer, |
21 | 22 | ) |
22 | 23 | from deepmd.pt.model.network.utils import ( |
@@ -59,6 +60,8 @@ def __init__( |
59 | 60 | sel_reduce_factor: float = 10.0, |
60 | 61 | smooth_edge_update: bool = False, |
61 | 62 | update_use_layernorm: bool = False, |
| 63 | + use_gated_mlp: bool = False, |
| 64 | + gated_mlp_norm: str = "none", |
62 | 65 | activation_function: str = "silu", |
63 | 66 | update_style: str = "res_residual", |
64 | 67 | update_residual: float = 0.1, |
@@ -98,6 +101,10 @@ def __init__( |
98 | 101 | self.update_residual = update_residual |
99 | 102 | self.update_residual_init = update_residual_init |
100 | 103 | self.update_use_layernorm = update_use_layernorm |
| 104 | + self.use_gated_mlp = use_gated_mlp |
| 105 | + if self.use_gated_mlp: |
| 106 | + assert not optim_update, "Gated MLP does not support optim update!" |
| 107 | + self.gated_mlp_norm = gated_mlp_norm |
101 | 108 | self.a_compress_e_rate = a_compress_e_rate |
102 | 109 | self.a_compress_use_split = a_compress_use_split |
103 | 110 | self.precision = precision |
@@ -160,12 +167,22 @@ def __init__( |
160 | 167 | ) |
161 | 168 |
|
162 | 169 | # node edge message |
163 | | - self.node_edge_linear = MLPLayer( |
164 | | - self.edge_info_dim, |
165 | | - self.n_multi_edge_message * n_dim, |
166 | | - precision=precision, |
167 | | - seed=child_seed(seed, 4), |
168 | | - ) |
| 170 | + if not self.use_gated_mlp: |
| 171 | + self.node_edge_linear = MLPLayer( |
| 172 | + self.edge_info_dim, |
| 173 | + self.n_multi_edge_message * n_dim, |
| 174 | + precision=precision, |
| 175 | + seed=child_seed(seed, 4), |
| 176 | + ) |
| 177 | + else: |
| 178 | + self.node_edge_linear = GatedMLP( |
| 179 | + self.edge_info_dim, |
| 180 | + self.n_multi_edge_message * n_dim, |
| 181 | + activation_function=self.activation_function, |
| 182 | + norm=self.gated_mlp_norm, |
| 183 | + precision=precision, |
| 184 | + seed=child_seed(seed, 4), |
| 185 | + ) |
169 | 186 | if self.update_style == "res_residual": |
170 | 187 | for head_index in range(self.n_multi_edge_message): |
171 | 188 | self.n_residual.append( |
@@ -245,12 +262,22 @@ def __init__( |
245 | 262 | self.a_compress_e_linear = None |
246 | 263 |
|
247 | 264 | # edge angle message |
248 | | - self.edge_angle_linear1 = MLPLayer( |
249 | | - self.angle_dim, |
250 | | - self.e_dim, |
251 | | - precision=precision, |
252 | | - seed=child_seed(seed, 10), |
253 | | - ) |
| 265 | + if not self.use_gated_mlp: |
| 266 | + self.edge_angle_linear1 = MLPLayer( |
| 267 | + self.angle_dim, |
| 268 | + self.e_dim, |
| 269 | + precision=precision, |
| 270 | + seed=child_seed(seed, 10), |
| 271 | + ) |
| 272 | + else: |
| 273 | + self.edge_angle_linear1 = GatedMLP( |
| 274 | + self.angle_dim, |
| 275 | + self.e_dim, |
| 276 | + activation_function=self.activation_function, |
| 277 | + norm=self.gated_mlp_norm, |
| 278 | + precision=precision, |
| 279 | + seed=child_seed(seed, 10), |
| 280 | + ) |
254 | 281 | self.edge_angle_linear2 = MLPLayer( |
255 | 282 | self.e_dim, |
256 | 283 | self.e_dim, |
|
0 commit comments