Skip to content

Commit 06f666f

Browse files
committed
add e3nn_angle_use_cross
1 parent 774339b commit 06f666f

5 files changed

Lines changed: 41 additions & 8 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(
9595
e3nn_use_edge_feat_weights: bool = False,
9696
use_e3nn_angle_conv: bool = False,
9797
e3nn_angle_conv_l_max: int = 2,
98+
e3nn_angle_use_cross: bool = False,
9899
) -> None:
99100
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
100101
@@ -241,6 +242,7 @@ def __init__(
241242
self.e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
242243
self.use_e3nn_angle_conv = use_e3nn_angle_conv
243244
self.e3nn_angle_conv_l_max = e3nn_angle_conv_l_max
245+
self.e3nn_angle_use_cross = e3nn_angle_use_cross
244246

245247
def __getitem__(self, key):
246248
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def init_subclass_params(sub_data, sub_class):
234234
e3nn_angle_conv_l_max=self.repflow_args.e3nn_angle_conv_l_max,
235235
e3nn_use_edge_feat_weights=self.repflow_args.e3nn_use_edge_feat_weights,
236236
e3nn_conv_l_max=self.repflow_args.e3nn_conv_l_max,
237+
e3nn_angle_use_cross=self.repflow_args.e3nn_angle_use_cross,
237238
exclude_types=exclude_types,
238239
env_protection=env_protection,
239240
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def __init__(
110110
e3nn_angle_conv_args: dict = {},
111111
use_e3nn_angle_conv: bool = False,
112112
e3nn_angle_conv_pattern: str = "64x0e+32x1e+32x2e",
113+
e3nn_angle_use_cross: bool = False,
113114
activation_function: str = "silu",
114115
update_style: str = "res_residual",
115116
update_residual: float = 0.1,
@@ -571,6 +572,7 @@ def __init__(
571572
self.use_e3nn_angle_conv = use_e3nn_angle_conv
572573
self.e3nn_angle_conv_pattern = e3nn_angle_conv_pattern
573574
self.e3nn_angle_conv_args = e3nn_angle_conv_args
575+
self.e3nn_angle_use_cross = e3nn_angle_use_cross
574576
if self.use_e3nn_angle_conv:
575577
self.e3nn_angle_conv_block = IrrepsAngleBlock(**self.e3nn_angle_conv_args, weight_layer_act=self.activation_function)
576578
if self.update_style == "res_residual":
@@ -1209,6 +1211,7 @@ def forward(
12091211
node_sph_embed: Optional[torch.Tensor] = None, # nf x nloc x num_sph_node
12101212
edge_angle_filter: Optional[torch.Tensor] = None, # n_angle x num_sph
12111213
edge_sph_embed: Optional[torch.Tensor] = None, # n_edge x num_sph
1214+
angle_weights: Optional[torch.Tensor] = None, # n_angle x 8
12121215
):
12131216
"""
12141217
Parameters
@@ -1741,8 +1744,12 @@ def forward(
17411744
assert edge_sph_embed is not None
17421745
assert edge_angle_filter is not None
17431746
assert angle_index is not None
1744-
angle_weights = angle_ebd
1745-
edge_sph_embed = self.e3nn_angle_conv_block(edge_sph_embed, edge_angle_filter, angle_weights, angle_index, a_sw)
1747+
if self.e3nn_angle_use_cross:
1748+
assert angle_weights is not None
1749+
angle_weights_input = angle_weights
1750+
else:
1751+
angle_weights_input = angle_ebd
1752+
edge_sph_embed = self.e3nn_angle_conv_block(edge_sph_embed, edge_angle_filter, angle_weights_input, angle_index, a_sw)
17461753
# node_sph_embed = node_sph_embed
17471754
edge_sph_conv_update = edge_sph_embed[:, :self.e_dim].clone() # avoid following in-place op
17481755
e_update_list.append(edge_sph_conv_update)

deepmd/pt/model/descriptor/repflows.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def __init__(
189189
use_e3nn_angle_conv: bool = False,
190190
e3nn_angle_conv_l_max: int = 2,
191191
e3nn_angle_conv_pattern: str = "64x0e+32x1e+32x2e",
192+
e3nn_angle_use_cross: bool = False,
192193
seed: Optional[Union[int, list[int]]] = None,
193194
) -> None:
194195
r"""
@@ -491,6 +492,7 @@ def __init__(
491492
self.use_e3nn_angle_conv = use_e3nn_angle_conv
492493
self.e3nn_angle_conv_l_max = e3nn_angle_conv_l_max
493494
self.e3nn_angle_conv_pattern = e3nn_angle_conv_pattern
495+
self.e3nn_angle_use_cross = e3nn_angle_use_cross
494496

495497
if not self.edge_use_esen_rbf:
496498
self.edge_embd = MLPLayer(
@@ -607,7 +609,7 @@ def __init__(
607609
"irreps_out": irreps_out,
608610
"denominator": 1.0 if not self.use_e3nn_denominator else self.dynamic_e_sel / 4,
609611
"train_denominator": True,
610-
"weight_layer_input_to_hidden": [8, 64, 64] if not self.e3nn_use_edge_feat_weights else [self.e_dim],
612+
"weight_layer_input_to_hidden": [8, 64, 64] if not self.e3nn_use_edge_feat_weights else [self.e_dim, 64, 64],
611613
}
612614
irreps_x = irreps_out
613615

@@ -627,7 +629,7 @@ def __init__(
627629
"irreps_out": irreps_edge_out,
628630
"denominator": 1.0 if not self.use_e3nn_denominator else self.dynamic_a_sel / 4,
629631
"train_denominator": True,
630-
"weight_layer_input_to_hidden": [self.a_dim],
632+
"weight_layer_input_to_hidden": [8, 64, 64] if self.e3nn_angle_use_cross else [self.a_dim],
631633
}
632634
irreps_edge = irreps_edge_out
633635

@@ -699,6 +701,7 @@ def __init__(
699701
use_e3nn_angle_conv=self.use_e3nn_angle_conv,
700702
e3nn_angle_conv_args=e3nn_angle_conv_args,
701703
e3nn_angle_conv_pattern=self.e3nn_angle_conv_pattern,
704+
e3nn_angle_use_cross=self.e3nn_angle_use_cross,
702705
seed=child_seed(child_seed(seed, 1), ii),
703706
)
704707
)
@@ -1249,15 +1252,28 @@ def forward(
12491252
assert self.edge_spherical_embd_for_angle is not None
12501253
assert self.edge_to_angle_filter_prod is not None
12511254
assert self.edge_to_angle_filter_linear is not None
1252-
edge_sph_for_angle = self.edge_spherical_embd_for_angle(diff)
1255+
edge_sph_embed = edge_ebd
12531256
edge_i_index = angle_index[:, 1]
12541257
edge_j_index = angle_index[:, 2]
1255-
edge_angle_filter_tp = self.edge_to_angle_filter_prod(edge_sph_for_angle[edge_i_index], edge_sph_for_angle[edge_j_index])
1256-
edge_angle_filter = self.edge_to_angle_filter_linear(edge_angle_filter_tp)
1257-
edge_sph_embed = edge_ebd
1258+
if not self.e3nn_angle_use_cross:
1259+
edge_sph_for_angle = self.edge_spherical_embd_for_angle(diff)
1260+
edge_angle_filter_tp = self.edge_to_angle_filter_prod(edge_sph_for_angle[edge_i_index], edge_sph_for_angle[edge_j_index])
1261+
edge_angle_filter = self.edge_to_angle_filter_linear(edge_angle_filter_tp)
1262+
angle_weights = None
1263+
else:
1264+
angle_cross_vec = torch.cross(diff[edge_i_index], diff[edge_j_index], dim=-1)
1265+
edge_angle_filter = self.edge_spherical_embd_for_angle(angle_cross_vec)
1266+
# angle basis as weights
1267+
# 1 - 1e-6 for torch.acos stability
1268+
cosine_ij = cosine_ij[a_nlist_mask]
1269+
sine_ij = torch.sqrt(1 - cosine_ij ** 2)
1270+
theta = torch.acos(cosine_ij).unsqueeze(-1)
1271+
theta_list = torch.cat([theta*2, theta*4, theta*8], dim=-1)
1272+
angle_weights = torch.cat([cosine_ij.unsqueeze(-1), torch.cos(theta_list), sine_ij.unsqueeze(-1), torch.sin(theta_list)], dim=-1)
12581273
else:
12591274
edge_angle_filter = None
12601275
edge_sph_embed = None
1276+
angle_weights = None
12611277

12621278
for idx, ll in enumerate(self.layers):
12631279
# node_ebd: nb x nloc x n_dim
@@ -1379,6 +1395,7 @@ def forward(
13791395
node_sph_embed=node_sph_embed,
13801396
edge_angle_filter=edge_angle_filter,
13811397
edge_sph_embed=edge_sph_embed,
1398+
angle_weights=angle_weights,
13821399
)
13831400
# may cause jit slow, todo fix
13841401
elif not self.rk_update_diff_layer:

deepmd/utils/argcheck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,6 +1983,12 @@ def dpa3_repflow_args():
19831983
optional=True,
19841984
default=2,
19851985
),
1986+
Argument(
1987+
"e3nn_angle_use_cross",
1988+
bool,
1989+
optional=True,
1990+
default=False,
1991+
),
19861992
]
19871993

19881994

0 commit comments

Comments
 (0)