Skip to content

Commit 005fecf

Browse files
committed
add e3nn_conv_weights_use_tebd
1 parent 126259e commit 005fecf

4 files changed

Lines changed: 19 additions & 2 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def __init__(
100100
e3nn_conv_use_edge_sh_feat: bool = False,
101101
edge_sh_feat_use_rbf_weights: bool = False,
102102
e3nn_conv_use_vi: bool = False,
103+
e3nn_conv_weights_use_tebd: bool = False,
103104
) -> None:
104105
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
105106
@@ -251,6 +252,7 @@ def __init__(
251252
self.e3nn_conv_use_edge_sh_feat = e3nn_conv_use_edge_sh_feat
252253
self.edge_sh_feat_use_rbf_weights = edge_sh_feat_use_rbf_weights
253254
self.e3nn_conv_use_vi = e3nn_conv_use_vi
255+
self.e3nn_conv_weights_use_tebd = e3nn_conv_weights_use_tebd
254256

255257
def __getitem__(self, key):
256258
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def init_subclass_params(sub_data, sub_class):
233233
e3nn_conv_use_edge_sh_feat=self.repflow_args.e3nn_conv_use_edge_sh_feat,
234234
edge_sh_feat_use_rbf_weights=self.repflow_args.edge_sh_feat_use_rbf_weights,
235235
e3nn_conv_use_vi=self.repflow_args.e3nn_conv_use_vi,
236+
e3nn_conv_weights_use_tebd=self.repflow_args.e3nn_conv_weights_use_tebd,
236237
use_e3nn_angle_conv=self.repflow_args.use_e3nn_angle_conv,
237238
e3nn_angle_conv_l_max=self.repflow_args.e3nn_angle_conv_l_max,
238239
e3nn_use_edge_feat_weights=self.repflow_args.e3nn_use_edge_feat_weights,

deepmd/pt/model/descriptor/repflows.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def __init__(
187187
e3nn_conv_use_edge_sh_feat: bool = False,
188188
edge_sh_feat_use_rbf_weights: bool = False,
189189
e3nn_conv_use_vi: bool = False,
190+
e3nn_conv_weights_use_tebd: bool = False,
190191
e3nn_conv_l_max: int = 3,
191192
e3nn_use_edge_feat_weights: bool = False,
192193
use_e3nn_angle_conv: bool = False,
@@ -493,6 +494,7 @@ def __init__(
493494
self.use_e3nn_denominator = use_e3nn_denominator
494495
self.e3nn_conv_use_edge_sh_feat = e3nn_conv_use_edge_sh_feat
495496
self.e3nn_conv_use_vi = e3nn_conv_use_vi
497+
self.e3nn_conv_weights_use_tebd = e3nn_conv_weights_use_tebd
496498
if self.e3nn_conv_use_vi:
497499
assert e3nn_conv_use_edge_sh_feat, "e3nn_conv_use_edge_sh_feat must be True when e3nn_conv_use_vi is True"
498500
self.e3nn_conv_l_max = e3nn_conv_l_max
@@ -651,7 +653,7 @@ def __init__(
651653
"irreps_out": irreps_out,
652654
"denominator": 1.0 if not self.use_e3nn_denominator else self.dynamic_e_sel / 4,
653655
"train_denominator": True,
654-
"weight_layer_input_to_hidden": [8, 64, 64] if not self.e3nn_use_edge_feat_weights else [self.e_dim, 64, 64],
656+
"weight_layer_input_to_hidden": [8 if not self.e3nn_conv_weights_use_tebd else 8 + 2*self.n_dim, 64, 64] if not self.e3nn_use_edge_feat_weights else [self.e_dim, 64, 64],
655657
}
656658
irreps_x = irreps_out
657659

@@ -1284,7 +1286,13 @@ def forward(
12841286
assert edge_dist is not None
12851287
# n_edge x rbf
12861288
edge_env = self.edge_env(edge_dist/self.e_rcut)
1287-
edge_rbf_ebd = self.edge_rbf_embed(edge_dist) * edge_env
1289+
if not self.e3nn_conv_weights_use_tebd:
1290+
edge_rbf_ebd = self.edge_rbf_embed(edge_dist) * edge_env
1291+
else:
1292+
edge_src = edge_index[:, 1]
1293+
edge_dst = edge_index[:, 0]
1294+
atype_embd_reshape = atype_embd.reshape(nframes * nloc, -1)
1295+
edge_rbf_ebd = torch.cat([self.edge_rbf_embed(edge_dist), atype_embd_reshape[edge_dst], atype_embd_reshape[edge_src]], dim=-1) * edge_env
12881296
# n_edge x num_sph(16)
12891297
edge_sph = self.edge_spherical_embd(diff)
12901298
node_sph_embed = node_ebd

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,6 +2013,12 @@ def dpa3_repflow_args():
20132013
optional=True,
20142014
default=False,
20152015
),
2016+
Argument(
2017+
"e3nn_conv_weights_use_tebd",
2018+
bool,
2019+
optional=True,
2020+
default=False,
2021+
),
20162022
]
20172023

20182024

0 commit comments

Comments
 (0)