|
17 | 17 | get_residual, |
18 | 18 | ) |
19 | 19 | from deepmd.pt.model.network.mlp import ( |
| 20 | + FeedForward, |
20 | 21 | MLPLayer, |
21 | 22 | ) |
22 | 23 | from deepmd.pt.model.network.utils import ( |
@@ -63,6 +64,9 @@ def __init__( |
63 | 64 | d_sel: int = 10, |
64 | 65 | d_rcut: float = 2.8, |
65 | 66 | d_rcut_smth: float = 2.0, |
| 67 | + use_ffn_node_edge_message: bool = False, |
| 68 | + use_ffn_edge_angle_message: bool = False, |
| 69 | + ffn_hidden_dim: int = 1024, |
66 | 70 | activation_function: str = "silu", |
67 | 71 | update_style: str = "res_residual", |
68 | 72 | update_residual: float = 0.1, |
@@ -119,6 +123,11 @@ def __init__( |
119 | 123 | self.d_rcut = d_rcut |
120 | 124 | self.d_rcut_smth = d_rcut_smth |
121 | 125 | self.dynamic_d_sel = (self.d_sel * 4) / self.sel_reduce_factor |
| 126 | + self.use_ffn_node_edge_message = use_ffn_node_edge_message |
| 127 | + self.use_ffn_edge_angle_message = use_ffn_edge_angle_message |
| 128 | + self.ffn_hidden_dim = ffn_hidden_dim |
| 129 | + if self.use_ffn_node_edge_message or self.use_ffn_edge_angle_message: |
| 130 | + assert not self.optim_update, "FFN does not support optim update!" |
122 | 131 |
|
123 | 132 | if self.update_dihedral: |
124 | 133 | assert self.use_dynamic_sel, "Dihedral update requires dynamic selection!" |
@@ -174,11 +183,20 @@ def __init__( |
174 | 183 | ) |
175 | 184 |
|
176 | 185 | # node edge message |
177 | | - self.node_edge_linear = MLPLayer( |
178 | | - self.edge_info_dim, |
179 | | - self.n_multi_edge_message * n_dim, |
180 | | - precision=precision, |
181 | | - seed=child_seed(seed, 4), |
| 186 | + self.node_edge_linear = ( |
| 187 | + MLPLayer( |
| 188 | + self.edge_info_dim, |
| 189 | + self.n_multi_edge_message * n_dim, |
| 190 | + precision=precision, |
| 191 | + seed=child_seed(seed, 4), |
| 192 | + ) |
| 193 | + if not self.use_ffn_node_edge_message |
| 194 | + else FeedForward( |
| 195 | + self.edge_info_dim, |
| 196 | + self.n_multi_edge_message * n_dim, |
| 197 | + self.ffn_hidden_dim, |
| 198 | + activation_function=self.activation_function, |
| 199 | + ) |
182 | 200 | ) |
183 | 201 | if self.update_style == "res_residual": |
184 | 202 | for head_index in range(self.n_multi_edge_message): |
@@ -248,11 +266,20 @@ def __init__( |
248 | 266 | self.a_compress_e_linear = None |
249 | 267 |
|
250 | 268 | # edge angle message |
251 | | - self.edge_angle_linear1 = MLPLayer( |
252 | | - self.angle_dim, |
253 | | - self.e_dim, |
254 | | - precision=precision, |
255 | | - seed=child_seed(seed, 10), |
| 269 | + self.edge_angle_linear1 = ( |
| 270 | + MLPLayer( |
| 271 | + self.angle_dim, |
| 272 | + self.e_dim, |
| 273 | + precision=precision, |
| 274 | + seed=child_seed(seed, 10), |
| 275 | + ) |
| 276 | + if not self.use_ffn_edge_angle_message |
| 277 | + else FeedForward( |
| 278 | + self.angle_dim, |
| 279 | + self.e_dim, |
| 280 | + self.ffn_hidden_dim, |
| 281 | + activation_function=self.activation_function, |
| 282 | + ) |
256 | 283 | ) |
257 | 284 | self.edge_angle_linear2 = MLPLayer( |
258 | 285 | self.e_dim, |
|
0 commit comments