Skip to content

Commit da1dc99

Browse files
committed
feat: add ffn for ne and ea
1 parent e3f65be commit da1dc99

6 files changed

Lines changed: 153 additions & 10 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def __init__(
4141
d_sel: int = 10,
4242
d_rcut: float = 2.8,
4343
d_rcut_smth: float = 2.0,
44+
use_ffn_node_edge_message: bool = False,
45+
use_ffn_edge_angle_message: bool = False,
46+
ffn_hidden_dim: int = 1024,
4447
) -> None:
4548
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
4649
@@ -131,6 +134,9 @@ def __init__(
131134
self.d_sel = d_sel
132135
self.d_rcut = d_rcut
133136
self.d_rcut_smth = d_rcut_smth
137+
self.use_ffn_node_edge_message = use_ffn_node_edge_message
138+
self.use_ffn_edge_angle_message = use_ffn_edge_angle_message
139+
self.ffn_hidden_dim = ffn_hidden_dim
134140

135141
def __getitem__(self, key):
136142
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ def init_subclass_params(sub_data, sub_class):
176176
d_sel=self.repflow_args.d_sel,
177177
d_rcut=self.repflow_args.d_rcut,
178178
d_rcut_smth=self.repflow_args.d_rcut_smth,
179+
use_ffn_node_edge_message=self.repflow_args.use_ffn_node_edge_message,
180+
use_ffn_edge_angle_message=self.repflow_args.use_ffn_edge_angle_message,
181+
ffn_hidden_dim=self.repflow_args.ffn_hidden_dim,
179182
exclude_types=exclude_types,
180183
env_protection=env_protection,
181184
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
get_residual,
1818
)
1919
from deepmd.pt.model.network.mlp import (
20+
FeedForward,
2021
MLPLayer,
2122
)
2223
from deepmd.pt.model.network.utils import (
@@ -63,6 +64,9 @@ def __init__(
6364
d_sel: int = 10,
6465
d_rcut: float = 2.8,
6566
d_rcut_smth: float = 2.0,
67+
use_ffn_node_edge_message: bool = False,
68+
use_ffn_edge_angle_message: bool = False,
69+
ffn_hidden_dim: int = 1024,
6670
activation_function: str = "silu",
6771
update_style: str = "res_residual",
6872
update_residual: float = 0.1,
@@ -119,6 +123,11 @@ def __init__(
119123
self.d_rcut = d_rcut
120124
self.d_rcut_smth = d_rcut_smth
121125
self.dynamic_d_sel = (self.d_sel * 4) / self.sel_reduce_factor
126+
self.use_ffn_node_edge_message = use_ffn_node_edge_message
127+
self.use_ffn_edge_angle_message = use_ffn_edge_angle_message
128+
self.ffn_hidden_dim = ffn_hidden_dim
129+
if self.use_ffn_node_edge_message or self.use_ffn_edge_angle_message:
130+
assert not self.optim_update, "FFN does not support optim update!"
122131

123132
if self.update_dihedral:
124133
assert self.use_dynamic_sel, "Dihedral update requires dynamic selection!"
@@ -174,11 +183,20 @@ def __init__(
174183
)
175184

176185
# node edge message
177-
self.node_edge_linear = MLPLayer(
178-
self.edge_info_dim,
179-
self.n_multi_edge_message * n_dim,
180-
precision=precision,
181-
seed=child_seed(seed, 4),
186+
self.node_edge_linear = (
187+
MLPLayer(
188+
self.edge_info_dim,
189+
self.n_multi_edge_message * n_dim,
190+
precision=precision,
191+
seed=child_seed(seed, 4),
192+
)
193+
if not self.use_ffn_node_edge_message
194+
else FeedForward(
195+
self.edge_info_dim,
196+
self.n_multi_edge_message * n_dim,
197+
self.ffn_hidden_dim,
198+
activation_function=self.activation_function,
199+
)
182200
)
183201
if self.update_style == "res_residual":
184202
for head_index in range(self.n_multi_edge_message):
@@ -248,11 +266,20 @@ def __init__(
248266
self.a_compress_e_linear = None
249267

250268
# edge angle message
251-
self.edge_angle_linear1 = MLPLayer(
252-
self.angle_dim,
253-
self.e_dim,
254-
precision=precision,
255-
seed=child_seed(seed, 10),
269+
self.edge_angle_linear1 = (
270+
MLPLayer(
271+
self.angle_dim,
272+
self.e_dim,
273+
precision=precision,
274+
seed=child_seed(seed, 10),
275+
)
276+
if not self.use_ffn_edge_angle_message
277+
else FeedForward(
278+
self.angle_dim,
279+
self.e_dim,
280+
self.ffn_hidden_dim,
281+
activation_function=self.activation_function,
282+
)
256283
)
257284
self.edge_angle_linear2 = MLPLayer(
258285
self.e_dim,

deepmd/pt/model/descriptor/repflows.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def __init__(
116116
d_sel: int = 10,
117117
d_rcut: float = 2.8,
118118
d_rcut_smth: float = 2.0,
119+
use_ffn_node_edge_message: bool = False,
120+
use_ffn_edge_angle_message: bool = False,
121+
ffn_hidden_dim: int = 1024,
119122
optim_update: bool = True,
120123
seed: Optional[Union[int, list[int]]] = None,
121124
) -> None:
@@ -248,6 +251,9 @@ def __init__(
248251
self.d_sel = d_sel
249252
self.d_rcut = d_rcut
250253
self.d_rcut_smth = d_rcut_smth
254+
self.use_ffn_node_edge_message = use_ffn_node_edge_message
255+
self.use_ffn_edge_angle_message = use_ffn_edge_angle_message
256+
self.ffn_hidden_dim = ffn_hidden_dim
251257

252258
self.n_dim = n_dim
253259
self.e_dim = e_dim
@@ -321,6 +327,9 @@ def __init__(
321327
d_sel=self.d_sel,
322328
d_rcut=self.d_rcut,
323329
d_rcut_smth=self.d_rcut_smth,
330+
use_ffn_node_edge_message=self.use_ffn_node_edge_message,
331+
use_ffn_edge_angle_message=self.use_ffn_edge_angle_message,
332+
ffn_hidden_dim=self.ffn_hidden_dim,
324333
seed=child_seed(child_seed(seed, 1), ii),
325334
)
326335
)

deepmd/pt/model/network/mlp.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
to_numpy_array,
4141
to_torch_tensor,
4242
)
43+
from deepmd.utils.version import (
44+
check_version_compatibility,
45+
)
4346

4447

4548
def empty_t(shape, precision):
@@ -278,6 +281,83 @@ def check_load_param(ss):
278281
return obj
279282

280283

284+
class FeedForward(nn.Module):
285+
"""
286+
A feed forward network with two linear layers and an activation function.
287+
No dropout, no gate and no residual connection.
288+
"""
289+
290+
def __init__(
291+
self,
292+
num_in: int,
293+
num_out: int,
294+
hidden_dim: int,
295+
activation_function: Optional[str] = None,
296+
bias: bool = False,
297+
) -> None:
298+
super().__init__()
299+
self.num_in = num_in
300+
self.num_out = num_out
301+
self.hidden_dim = hidden_dim
302+
self.activation_function = activation_function
303+
self.bias = bias
304+
self.w1 = MLPLayer(
305+
num_in=num_in,
306+
num_out=hidden_dim,
307+
bias=bias,
308+
)
309+
self.act = ActivationFn(activation_function)
310+
self.w2 = MLPLayer(
311+
num_in=hidden_dim,
312+
num_out=num_out,
313+
bias=bias,
314+
)
315+
316+
def forward(self, x: torch.Tensor) -> torch.Tensor:
317+
return self.w2(self.act(self.w1(x)))
318+
319+
def serialize(self) -> dict:
320+
"""Serialize the networks to a dict.
321+
322+
Returns
323+
-------
324+
dict
325+
The serialized networks.
326+
"""
327+
data = {
328+
"@class": "FeedForward",
329+
"@version": 1,
330+
"num_in": self.num_in,
331+
"num_out": self.num_out,
332+
"hidden_dim": self.hidden_dim,
333+
"activation_function": self.activation_function,
334+
"bias": self.bias,
335+
"w1": self.w1.serialize(),
336+
"w2": self.w2.serialize(),
337+
}
338+
return data
339+
340+
@classmethod
341+
def deserialize(cls, data: dict) -> "FeedForward":
342+
"""Deserialize the networks from a dict.
343+
344+
Parameters
345+
----------
346+
data : dict
347+
The dict to deserialize from.
348+
"""
349+
data = data.copy()
350+
check_version_compatibility(data.pop("@version"), 1, 1)
351+
data.pop("@class")
352+
w1 = data.pop("w1")
353+
w2 = data.pop("w2")
354+
355+
obj = cls(**data)
356+
obj.w1 = MLPLayer.deserialize(w1)
357+
obj.w2 = MLPLayer.deserialize(w2)
358+
return obj
359+
360+
281361
MLP_ = make_multilayer_network(MLPLayer, nn.Module)
282362

283363

deepmd/utils/argcheck.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1646,6 +1646,24 @@ def dpa3_repflow_args():
16461646
optional=True,
16471647
default=2.0,
16481648
),
1649+
Argument(
1650+
"use_ffn_node_edge_message",
1651+
bool,
1652+
optional=True,
1653+
default=False,
1654+
),
1655+
Argument(
1656+
"use_ffn_edge_angle_message",
1657+
bool,
1658+
optional=True,
1659+
default=False,
1660+
),
1661+
Argument(
1662+
"ffn_hidden_dim",
1663+
int,
1664+
optional=True,
1665+
default=1024,
1666+
),
16491667
]
16501668

16511669

0 commit comments

Comments
 (0)