Skip to content

Commit 603b1b0

Browse files
committed
add e3nn_use_edge_feat_weights
1 parent 2654a75 commit 603b1b0

6 files changed

Lines changed: 19 additions & 1 deletion

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(
9292
e3nn_conv_pattern: str = "128x0e+64x1e+32x2e+32x3e",
9393
use_e3nn_denominator: bool = False,
9494
e3nn_conv_l_max: int = 3,
95+
e3nn_use_edge_feat_weights: bool = False,
9596
) -> None:
9697
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
9798
@@ -235,6 +236,7 @@ def __init__(
235236
self.e3nn_conv_pattern = e3nn_conv_pattern
236237
self.use_e3nn_denominator = use_e3nn_denominator
237238
self.e3nn_conv_l_max = e3nn_conv_l_max
239+
self.e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
238240

239241
def __getitem__(self, key):
240242
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def init_subclass_params(sub_data, sub_class):
230230
use_e3nn_conv=self.repflow_args.use_e3nn_conv,
231231
e3nn_conv_pattern=self.repflow_args.e3nn_conv_pattern,
232232
use_e3nn_denominator=self.repflow_args.use_e3nn_denominator,
233+
e3nn_use_edge_feat_weights=self.repflow_args.e3nn_use_edge_feat_weights,
233234
e3nn_conv_l_max=self.repflow_args.e3nn_conv_l_max,
234235
exclude_types=exclude_types,
235236
env_protection=env_protection,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def __init__(
104104
dropout_rate: float = 0.1,
105105
use_e3nn_conv: bool = False,
106106
e3nn_conv_pattern: str = "128x0e+64x1e+32x2e+32x3e",
107+
e3nn_use_edge_feat_weights: bool = False,
107108
e3nn_conv_args: dict = {},
108109
activation_function: str = "silu",
109110
update_style: str = "res_residual",
@@ -387,6 +388,7 @@ def __init__(
387388
self.use_e3nn_conv = use_e3nn_conv
388389
self.e3nn_conv_pattern = e3nn_conv_pattern
389390
self.e3nn_conv_args = e3nn_conv_args
391+
self.e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
390392
if self.use_e3nn_conv:
391393
self.e3nn_conv_block = IrrepsBlock(**self.e3nn_conv_args, weight_layer_act=self.activation_function)
392394
if self.update_style == "res_residual":
@@ -1450,7 +1452,8 @@ def forward(
14501452
assert edge_sph is not None
14511453
assert edge_rbf_ebd is not None
14521454
assert edge_index is not None
1453-
node_sph_embed = self.e3nn_conv_block(node_sph_embed, edge_sph, edge_rbf_ebd, edge_index)
1455+
edge_weights = edge_rbf_ebd if not self.e3nn_use_edge_feat_weights else edge_ebd
1456+
node_sph_embed = self.e3nn_conv_block(node_sph_embed, edge_sph, edge_weights, edge_index)
14541457
# node_sph_embed = node_sph_embed
14551458
sph_conv_update = node_sph_embed[:, :, :self.n_dim].clone() # avoid following in-place op
14561459
n_update_list.append(sph_conv_update)

deepmd/pt/model/descriptor/repflows.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def __init__(
184184
e3nn_conv_pattern: str = "128x0e+64x1e+32x2e+32x3e",
185185
use_e3nn_denominator: bool = False,
186186
e3nn_conv_l_max: int = 3,
187+
e3nn_use_edge_feat_weights: bool = False,
187188
seed: Optional[Union[int, list[int]]] = None,
188189
) -> None:
189190
r"""
@@ -482,6 +483,7 @@ def __init__(
482483
self.e3nn_conv_pattern = e3nn_conv_pattern
483484
self.use_e3nn_denominator = use_e3nn_denominator
484485
self.e3nn_conv_l_max = e3nn_conv_l_max
486+
self.e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
485487

486488
if not self.edge_use_esen_rbf:
487489
self.edge_embd = MLPLayer(
@@ -568,6 +570,7 @@ def __init__(
568570
"irreps_out": irreps_out,
569571
"denominator": 1.0 if not self.use_e3nn_denominator else self.dynamic_e_sel / 4,
570572
"train_denominator": True,
573+
"weight_layer_input_to_hidden": [8, 64, 64] if not self.e3nn_use_edge_feat_weights else [self.e_dim],
571574
}
572575
irreps_x = irreps_out
573576
layers.append(
@@ -633,6 +636,7 @@ def __init__(
633636
dropout_rate=self.dropout_rate,
634637
use_e3nn_conv=self.use_e3nn_conv,
635638
e3nn_conv_pattern=self.e3nn_conv_pattern,
639+
e3nn_use_edge_feat_weights=self.e3nn_use_edge_feat_weights,
636640
e3nn_conv_args=e3nn_conv_args,
637641
seed=child_seed(child_seed(seed, 1), ii),
638642
)

deepmd/pt/model/network/e3nn_networks.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def __init__(
245245
weight_layer_act="tanh",
246246
denominator: float = 1.0,
247247
train_denominator: bool = False,
248+
weight_layer_input_to_hidden: list[int] = [8, 64, 64],
248249
) -> None:
249250
super().__init__()
250251

@@ -256,6 +257,7 @@ def __init__(
256257
weight_layer_act=weight_layer_act,
257258
denominator=denominator,
258259
train_denominator=train_denominator,
260+
weight_layer_input_to_hidden=weight_layer_input_to_hidden,
259261
)
260262

261263
# 3. gate

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,6 +1965,12 @@ def dpa3_repflow_args():
19651965
optional=True,
19661966
default=3,
19671967
),
1968+
Argument(
1969+
"e3nn_use_edge_feat_weights",
1970+
bool,
1971+
optional=True,
1972+
default=False,
1973+
),
19681974
]
19691975

19701976

0 commit comments

Comments
 (0)