Skip to content

Commit 1722ec1

Browse files
committed
feat: add concat rbf
1 parent 89aec09 commit 1722ec1

4 files changed

Lines changed: 21 additions & 2 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
use_ffn_edge_angle_message: bool = False,
4747
use_ffn_angle_angle_message: bool = False,
4848
ffn_hidden_dim: int = 1024,
49+
edge_use_concat_rbf: bool = False,
4950
edge_use_rbf: bool = False,
5051
edge_use_dist: bool = False,
5152
embed_use_bias: bool = True,
@@ -144,6 +145,7 @@ def __init__(
144145
self.use_ffn_edge_angle_message = use_ffn_edge_angle_message
145146
self.use_ffn_angle_angle_message = use_ffn_angle_angle_message
146147
self.ffn_hidden_dim = ffn_hidden_dim
148+
self.edge_use_concat_rbf = edge_use_concat_rbf
147149
self.edge_use_rbf = edge_use_rbf
148150
self.edge_use_dist = edge_use_dist
149151
self.embed_use_bias = embed_use_bias

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def init_subclass_params(sub_data, sub_class):
181181
use_ffn_edge_angle_message=self.repflow_args.use_ffn_edge_angle_message,
182182
use_ffn_angle_angle_message=self.repflow_args.use_ffn_angle_angle_message,
183183
ffn_hidden_dim=self.repflow_args.ffn_hidden_dim,
184+
edge_use_concat_rbf=self.repflow_args.edge_use_concat_rbf,
184185
edge_use_rbf=self.repflow_args.edge_use_rbf,
185186
edge_use_dist=self.repflow_args.edge_use_dist,
186187
embed_use_bias=self.repflow_args.embed_use_bias,

deepmd/pt/model/descriptor/repflows.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def __init__(
122122
use_ffn_edge_angle_message: bool = False,
123123
use_ffn_angle_angle_message: bool = False,
124124
ffn_hidden_dim: int = 1024,
125+
edge_use_concat_rbf: bool = False,
125126
edge_use_rbf: bool = False,
126127
edge_use_dist: bool = False,
127128
embed_use_bias: bool = True,
@@ -262,11 +263,15 @@ def __init__(
262263
self.use_ffn_edge_angle_message = use_ffn_edge_angle_message
263264
self.use_ffn_angle_angle_message = use_ffn_angle_angle_message
264265
self.ffn_hidden_dim = ffn_hidden_dim
266+
self.edge_use_concat_rbf = edge_use_concat_rbf
265267
self.edge_use_rbf = edge_use_rbf
266268
self.edge_use_dist = edge_use_dist
267269
self.embed_use_bias = embed_use_bias
268270
self.edge_embed_input_dim = 1
269-
if self.edge_use_rbf:
271+
if self.edge_use_concat_rbf:
272+
self.rbf = BesselBasis(self.e_rcut)
273+
self.edge_embed_input_dim = 1 + self.rbf.num_basis
274+
elif self.edge_use_rbf:
270275
self.rbf = BesselBasis(self.e_rcut)
271276
self.edge_embed_input_dim = self.rbf.num_basis
272277
else:
@@ -536,7 +541,7 @@ def forward(
536541
# get edge and angle embedding input
537542
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
538543
edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
539-
if self.edge_use_rbf or self.edge_use_dist:
544+
if self.edge_use_concat_rbf or self.edge_use_rbf or self.edge_use_dist:
540545
# nb x nloc x nnei x 1
541546
edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True)
542547
# nf x nloc x a_nnei x 3
@@ -668,6 +673,11 @@ def forward(
668673
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
669674
if self.edge_use_dist:
670675
edge_ebd = self.edge_embd(edge_input)
676+
elif self.edge_use_concat_rbf:
677+
assert self.rbf is not None
678+
edge_ebd = self.edge_embd(
679+
torch.cat([dmatrix[..., :1], self.rbf(edge_input)], dim=-1)
680+
)
671681
elif self.edge_use_rbf:
672682
assert self.rbf is not None
673683
edge_ebd = self.edge_embd(self.rbf(edge_input))

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,6 +1676,12 @@ def dpa3_repflow_args():
16761676
optional=True,
16771677
default=1024,
16781678
),
1679+
Argument(
1680+
"edge_use_concat_rbf",
1681+
bool,
1682+
optional=True,
1683+
default=False,
1684+
),
16791685
Argument(
16801686
"edge_use_rbf",
16811687
bool,

0 commit comments

Comments
 (0)