Skip to content

Commit 978a5da

Browse files
committed
add e3nn_conv_use_edge_sh_feat
1 parent eec08fa commit 978a5da

6 files changed

Lines changed: 243 additions & 30 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
e3nn_angle_conv_l_max: int = 2,
9898
e3nn_angle_use_cross: bool = False,
9999
e3nn_angle_only_single_angle: bool = False,
100+
e3nn_conv_use_edge_sh_feat: bool = False,
100101
) -> None:
101102
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
102103
@@ -245,6 +246,7 @@ def __init__(
245246
self.e3nn_angle_conv_l_max = e3nn_angle_conv_l_max
246247
self.e3nn_angle_use_cross = e3nn_angle_use_cross
247248
self.e3nn_angle_only_single_angle = e3nn_angle_only_single_angle
249+
self.e3nn_conv_use_edge_sh_feat = e3nn_conv_use_edge_sh_feat
248250

249251
def __getitem__(self, key):
250252
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def init_subclass_params(sub_data, sub_class):
230230
use_e3nn_conv=self.repflow_args.use_e3nn_conv,
231231
e3nn_conv_pattern=self.repflow_args.e3nn_conv_pattern,
232232
use_e3nn_denominator=self.repflow_args.use_e3nn_denominator,
233+
e3nn_conv_use_edge_sh_feat=self.repflow_args.e3nn_conv_use_edge_sh_feat,
233234
use_e3nn_angle_conv=self.repflow_args.use_e3nn_angle_conv,
234235
e3nn_angle_conv_l_max=self.repflow_args.e3nn_angle_conv_l_max,
235236
e3nn_use_edge_feat_weights=self.repflow_args.e3nn_use_edge_feat_weights,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def __init__(
106106
use_e3nn_conv: bool = False,
107107
e3nn_conv_pattern: str = "128x0e+64x1e+32x2e+32x3e",
108108
e3nn_use_edge_feat_weights: bool = False,
109+
e3nn_conv_use_edge_sh_feat: bool = False,
109110
e3nn_conv_args: dict = {},
110111
e3nn_angle_conv_args: dict = {},
111112
use_e3nn_angle_conv: bool = False,
@@ -394,8 +395,9 @@ def __init__(
394395
self.e3nn_conv_pattern = e3nn_conv_pattern
395396
self.e3nn_conv_args = e3nn_conv_args
396397
self.e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
398+
self.e3nn_conv_use_edge_sh_feat = e3nn_conv_use_edge_sh_feat
397399
if self.use_e3nn_conv:
398-
self.e3nn_conv_block = IrrepsBlock(**self.e3nn_conv_args, weight_layer_act="silu")
400+
self.e3nn_conv_block = IrrepsBlock(**self.e3nn_conv_args, e3nn_conv_use_edge_sh_feat=e3nn_conv_use_edge_sh_feat, weight_layer_act="silu")
399401
if self.update_style == "res_residual":
400402
self.n_residual.append(
401403
get_residual(
@@ -575,6 +577,10 @@ def __init__(
575577
self.e3nn_angle_use_cross = e3nn_angle_use_cross
576578
if self.use_e3nn_angle_conv:
577579
self.e3nn_angle_conv_block = IrrepsAngleBlock(**self.e3nn_angle_conv_args, weight_layer_act="silu")
580+
else:
581+
self.e3nn_angle_conv_block = None
582+
583+
if self.e3nn_conv_use_edge_sh_feat or self.use_e3nn_angle_conv:
578584
if self.update_style == "res_residual":
579585
self.e_residual.append(
580586
get_residual(
@@ -586,8 +592,7 @@ def __init__(
586592
)
587593
)
588594
residual_idx += 1
589-
else:
590-
self.e3nn_angle_conv_block = None
595+
591596

592597
# angle self message
593598
if not self.use_gated_mlp:
@@ -1482,7 +1487,11 @@ def forward(
14821487
assert edge_rbf_ebd is not None
14831488
assert edge_index is not None
14841489
edge_weights = edge_rbf_ebd if not self.e3nn_use_edge_feat_weights else edge_ebd
1485-
node_sph_embed = self.e3nn_conv_block(node_sph_embed, edge_sph, edge_weights, edge_index)
1490+
node_sph_embed, edge_sph_update = self.e3nn_conv_block(node_sph_embed, edge_sph, edge_weights, edge_index, edge_sph_embed)
1491+
1492+
if self.e3nn_conv_use_edge_sh_feat:
1493+
assert edge_sph_embed is not None
1494+
edge_sph_embed = edge_sph_embed + 0.1 * edge_sph_update
14861495
# node_sph_embed = node_sph_embed
14871496
sph_conv_update = node_sph_embed[:, :, :self.n_dim].clone() # avoid following in-place op
14881497
n_update_list.append(sph_conv_update)
@@ -1751,14 +1760,17 @@ def forward(
17511760
angle_weights_input = angle_ebd
17521761
edge_sph_embed = self.e3nn_angle_conv_block(edge_sph_embed, edge_angle_filter, angle_weights_input, angle_index, a_sw)
17531762
# node_sph_embed = node_sph_embed
1763+
1764+
if self.use_e3nn_angle_conv or self.e3nn_conv_use_edge_sh_feat:
1765+
assert edge_sph_embed is not None
17541766
edge_sph_conv_update = edge_sph_embed[:, :self.e_dim].clone() # avoid following in-place op
17551767
e_update_list.append(edge_sph_conv_update)
17561768

17571769
# update edge_ebd
17581770
e_updated = self.list_update(e_update_list, "edge")
17591771

17601772
# edge angle e3nn joint update
1761-
if self.use_e3nn_angle_conv:
1773+
if self.use_e3nn_angle_conv or self.e3nn_conv_use_edge_sh_feat:
17621774
assert edge_sph_embed is not None
17631775
edge_sph_embed[:, : self.e_dim] = e_updated
17641776

deepmd/pt/model/descriptor/repflows.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import torch
99

10-
from e3nn.o3 import Irreps, Linear
10+
from e3nn.o3 import Irreps, Linear, TensorProduct
1111
from e3nn.o3 import FullyConnectedTensorProduct as FCTP
1212
import sevenn.util as util
1313
from sevenn.nn.edge_embedding import (
@@ -184,6 +184,7 @@ def __init__(
184184
use_e3nn_conv: bool = False,
185185
e3nn_conv_pattern: str = "128x0e+64x1e+32x2e+32x3e",
186186
use_e3nn_denominator: bool = False,
187+
e3nn_conv_use_edge_sh_feat: bool = False,
187188
e3nn_conv_l_max: int = 3,
188189
e3nn_use_edge_feat_weights: bool = False,
189190
use_e3nn_angle_conv: bool = False,
@@ -488,6 +489,7 @@ def __init__(
488489
self.use_e3nn_conv = use_e3nn_conv
489490
self.e3nn_conv_pattern = e3nn_conv_pattern
490491
self.use_e3nn_denominator = use_e3nn_denominator
492+
self.e3nn_conv_use_edge_sh_feat = e3nn_conv_use_edge_sh_feat
491493
self.e3nn_conv_l_max = e3nn_conv_l_max
492494
self.e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
493495
self.use_e3nn_angle_conv = use_e3nn_angle_conv
@@ -564,10 +566,42 @@ def __init__(
564566
self.edge_rbf_embed = None
565567
self.edge_env = None
566568
self.edge_spherical_embd = None
567-
self.irreps_filter = "0e"
569+
self.irreps_filter = Irreps("0e")
570+
571+
if self.e3nn_conv_use_edge_sh_feat:
572+
irreps_edge = Irreps(f"{self.e_dim}x0e")
573+
irreps_filter = self.irreps_filter
574+
irreps_out = ((irreps_filter * self.e_dim).sort()[0]).simplify()
575+
self.edge_sph_embd_ir = Irreps(self.e3nn_angle_conv_pattern)
576+
instructions = []
577+
irreps_mid = []
578+
weight_numel = 0
579+
for i, (mul_x, ir_x) in enumerate(irreps_edge):
580+
for j, (_, ir_filter) in enumerate(irreps_filter):
581+
for ir_out in ir_x * ir_filter:
582+
if ir_out in irreps_out: # here we drop l > lmax
583+
k = len(irreps_mid)
584+
weight_numel += mul_x * 1 # path shape
585+
irreps_mid.append((mul_x, ir_out))
586+
instructions.append((i, j, k, 'uvu', False))
587+
588+
irreps_mid = Irreps(irreps_mid)
589+
self.edge_sph_embd_init_tp = TensorProduct(
590+
irreps_in1=irreps_edge,
591+
irreps_in2=irreps_filter,
592+
irreps_out=irreps_mid,
593+
instructions=instructions,
594+
shared_weights=False,
595+
internal_weights=False,
596+
)
597+
self.edge_sph_embd_init_linear = Linear(irreps_mid, self.edge_sph_embd_ir)
598+
else:
599+
self.edge_sph_embd_init_tp = None
600+
self.edge_sph_embd_ir = Irreps(f'{self.e_dim}x0e')
601+
self.edge_sph_embd_init_linear = None
568602

569603
# for edge angle e3nn conv
570-
irreps_edge = Irreps(f'{self.e_dim}x0e')
604+
irreps_edge = self.edge_sph_embd_ir
571605
self.angle_lmax = e3nn_angle_conv_l_max
572606
self.e3nn_angle_conv_pattern = Irreps("+".join(e3nn_angle_conv_pattern.split("+")[:self.angle_lmax+1]))
573607
if self.use_e3nn_angle_conv:
@@ -699,6 +733,7 @@ def __init__(
699733
use_e3nn_conv=self.use_e3nn_conv,
700734
e3nn_conv_pattern=self.e3nn_conv_pattern,
701735
e3nn_use_edge_feat_weights=self.e3nn_use_edge_feat_weights,
736+
e3nn_conv_use_edge_sh_feat=self.e3nn_conv_use_edge_sh_feat,
702737
e3nn_conv_args=e3nn_conv_args,
703738
use_e3nn_angle_conv=self.use_e3nn_angle_conv,
704739
e3nn_angle_conv_args=e3nn_angle_conv_args,
@@ -1249,12 +1284,23 @@ def forward(
12491284
edge_sph = None
12501285
node_sph_embed = None
12511286

1287+
if not self.e3nn_conv_use_edge_sh_feat:
1288+
if self.use_e3nn_angle_conv:
1289+
edge_sph_embed = edge_ebd
1290+
else:
1291+
edge_sph_embed = None
1292+
else:
1293+
assert edge_sph is not None
1294+
assert self.edge_sph_embd_init_tp is not None
1295+
assert self.edge_sph_embd_init_linear is not None
1296+
edge_sph_embed = self.edge_sph_embd_init_tp(edge_ebd, edge_sph)
1297+
edge_sph_embed = self.edge_sph_embd_init_linear(edge_sph_embed)
1298+
12521299
if self.use_e3nn_angle_conv:
12531300
assert self.use_dynamic_sel, "e3nn conv must use dynamic sel"
12541301
assert self.edge_spherical_embd_for_angle is not None
12551302
assert self.edge_to_angle_filter_prod is not None
12561303
assert self.edge_to_angle_filter_linear is not None
1257-
edge_sph_embed = edge_ebd
12581304
edge_i_index = angle_index[:, 1]
12591305
edge_j_index = angle_index[:, 2]
12601306
if not self.e3nn_angle_use_cross:
@@ -1277,7 +1323,6 @@ def forward(
12771323
angle_weights = torch.cat([cosine_ij.unsqueeze(-1), sine_ij.unsqueeze(-1)], dim=-1)
12781324
else:
12791325
edge_angle_filter = None
1280-
edge_sph_embed = None
12811326
angle_weights = None
12821327

12831328
for idx, ll in enumerate(self.layers):

0 commit comments

Comments
 (0)