Skip to content

Commit 774339b

Browse files
committed
add use_e3nn_angle_conv
1 parent 603b1b0 commit 774339b

6 files changed

Lines changed: 308 additions & 4 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def __init__(
9393
use_e3nn_denominator: bool = False,
9494
e3nn_conv_l_max: int = 3,
9595
e3nn_use_edge_feat_weights: bool = False,
96+
use_e3nn_angle_conv: bool = False,
97+
e3nn_angle_conv_l_max: int = 2,
9698
) -> None:
9799
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
98100
@@ -237,6 +239,8 @@ def __init__(
237239
self.use_e3nn_denominator = use_e3nn_denominator
238240
self.e3nn_conv_l_max = e3nn_conv_l_max
239241
self.e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
242+
self.use_e3nn_angle_conv = use_e3nn_angle_conv
243+
self.e3nn_angle_conv_l_max = e3nn_angle_conv_l_max
240244

241245
def __getitem__(self, key):
242246
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ def init_subclass_params(sub_data, sub_class):
230230
use_e3nn_conv=self.repflow_args.use_e3nn_conv,
231231
e3nn_conv_pattern=self.repflow_args.e3nn_conv_pattern,
232232
use_e3nn_denominator=self.repflow_args.use_e3nn_denominator,
233+
use_e3nn_angle_conv=self.repflow_args.use_e3nn_angle_conv,
234+
e3nn_angle_conv_l_max=self.repflow_args.e3nn_angle_conv_l_max,
233235
e3nn_use_edge_feat_weights=self.repflow_args.e3nn_use_edge_feat_weights,
234236
e3nn_conv_l_max=self.repflow_args.e3nn_conv_l_max,
235237
exclude_types=exclude_types,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
from deepmd.pt.model.network.e3nn_networks import (
4646
IrrepsBlock,
47+
IrrepsAngleBlock,
4748
)
4849

4950

@@ -106,6 +107,9 @@ def __init__(
106107
e3nn_conv_pattern: str = "128x0e+64x1e+32x2e+32x3e",
107108
e3nn_use_edge_feat_weights: bool = False,
108109
e3nn_conv_args: dict = {},
110+
e3nn_angle_conv_args: dict = {},
111+
use_e3nn_angle_conv: bool = False,
112+
e3nn_angle_conv_pattern: str = "64x0e+32x1e+32x2e",
109113
activation_function: str = "silu",
110114
update_style: str = "res_residual",
111115
update_residual: float = 0.1,
@@ -563,6 +567,26 @@ def __init__(
563567
)
564568
residual_idx += 1
565569

570+
# for edge angle e3nn conv
571+
self.use_e3nn_angle_conv = use_e3nn_angle_conv
572+
self.e3nn_angle_conv_pattern = e3nn_angle_conv_pattern
573+
self.e3nn_angle_conv_args = e3nn_angle_conv_args
574+
if self.use_e3nn_angle_conv:
575+
self.e3nn_angle_conv_block = IrrepsAngleBlock(**self.e3nn_angle_conv_args, weight_layer_act=self.activation_function)
576+
if self.update_style == "res_residual":
577+
self.e_residual.append(
578+
get_residual(
579+
e_dim,
580+
self.update_residual * self.residual_pref[residual_idx],
581+
self.update_residual_init,
582+
precision=precision,
583+
seed=child_seed(seed, 22),
584+
)
585+
)
586+
residual_idx += 1
587+
else:
588+
self.e3nn_angle_conv_block = None
589+
566590
# angle self message
567591
if not self.use_gated_mlp:
568592
self.angle_self_linear = MLPLayer(
@@ -1183,6 +1207,8 @@ def forward(
11831207
edge_rbf_ebd: Optional[torch.Tensor] = None, # n_edge x num_basis
11841208
edge_sph: Optional[torch.Tensor] = None, # n_edge x num_sph
11851209
node_sph_embed: Optional[torch.Tensor] = None, # nf x nloc x num_sph_node
1210+
edge_angle_filter: Optional[torch.Tensor] = None, # n_angle x num_sph
1211+
edge_sph_embed: Optional[torch.Tensor] = None, # n_edge x num_sph
11861212
):
11871213
"""
11881214
Parameters
@@ -1708,9 +1734,27 @@ def forward(
17081734
padding_edge_angle_update = self.EAM_rmsnorm(padding_edge_angle_update)
17091735

17101736
e_update_list.append(padding_edge_angle_update)
1737+
1738+
# for edge angle e3nn conv
1739+
if self.use_e3nn_angle_conv:
1740+
assert self.e3nn_angle_conv_block is not None
1741+
assert edge_sph_embed is not None
1742+
assert edge_angle_filter is not None
1743+
assert angle_index is not None
1744+
angle_weights = angle_ebd
1745+
edge_sph_embed = self.e3nn_angle_conv_block(edge_sph_embed, edge_angle_filter, angle_weights, angle_index, a_sw)
1746+
# node_sph_embed = node_sph_embed
1747+
edge_sph_conv_update = edge_sph_embed[:, :self.e_dim].clone() # avoid following in-place op
1748+
e_update_list.append(edge_sph_conv_update)
1749+
17111750
# update edge_ebd
17121751
e_updated = self.list_update(e_update_list, "edge")
17131752

1753+
# edge angle e3nn joint update
1754+
if self.use_e3nn_angle_conv:
1755+
assert edge_sph_embed is not None
1756+
edge_sph_embed[:, : self.e_dim] = e_updated
1757+
17141758
# angle self message
17151759
# nb x nloc x a_nnei x a_nnei x dim_a
17161760
if not self.optim_update:
@@ -1927,7 +1971,7 @@ def forward(
19271971

19281972
# update angle_ebd
19291973
a_updated = self.list_update(a_update_list, "angle")
1930-
return n_updated, e_updated, a_updated, d_updated, node_sph_embed
1974+
return n_updated, e_updated, a_updated, d_updated, node_sph_embed, edge_sph_embed
19311975

19321976
@torch.jit.export
19331977
def list_update_res_avg(

deepmd/pt/model/descriptor/repflows.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77

88
import torch
99

10-
from e3nn.o3 import Irreps
10+
from e3nn.o3 import Irreps, Linear
11+
from e3nn.o3 import FullyConnectedTensorProduct as FCTP
1112
import sevenn.util as util
1213
from sevenn.nn.edge_embedding import (
1314
SphericalEncoding,
@@ -185,6 +186,9 @@ def __init__(
185186
use_e3nn_denominator: bool = False,
186187
e3nn_conv_l_max: int = 3,
187188
e3nn_use_edge_feat_weights: bool = False,
189+
use_e3nn_angle_conv: bool = False,
190+
e3nn_angle_conv_l_max: int = 2,
191+
e3nn_angle_conv_pattern: str = "64x0e+32x1e+32x2e",
188192
seed: Optional[Union[int, list[int]]] = None,
189193
) -> None:
190194
r"""
@@ -484,6 +488,9 @@ def __init__(
484488
self.use_e3nn_denominator = use_e3nn_denominator
485489
self.e3nn_conv_l_max = e3nn_conv_l_max
486490
self.e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
491+
self.use_e3nn_angle_conv = use_e3nn_angle_conv
492+
self.e3nn_angle_conv_l_max = e3nn_angle_conv_l_max
493+
self.e3nn_angle_conv_pattern = e3nn_angle_conv_pattern
487494

488495
if not self.edge_use_esen_rbf:
489496
self.edge_embd = MLPLayer(
@@ -540,6 +547,7 @@ def __init__(
540547
self.force_embedding_linear = None
541548

542549
layers = []
550+
# for node edge e3nn conv
543551
irreps_x = Irreps(f'{self.n_dim}x0e')
544552
self.lmax = e3nn_conv_l_max
545553
self.e3nn_conv_pattern = Irreps("+".join(e3nn_conv_pattern.split("+")[:self.lmax+1]))
@@ -554,7 +562,36 @@ def __init__(
554562
self.edge_spherical_embd = None
555563
self.irreps_filter = "0e"
556564

565+
# for edge angle e3nn conv
566+
irreps_edge = Irreps(f'{self.e_dim}x0e')
567+
self.angle_lmax = e3nn_angle_conv_l_max
568+
self.e3nn_angle_conv_pattern = Irreps("+".join(e3nn_angle_conv_pattern.split("+")[:self.angle_lmax+1]))
569+
if self.use_e3nn_angle_conv:
570+
self.edge_spherical_embd_for_angle = SphericalEncoding(self.angle_lmax, parity=1, normalize=True)
571+
self.irreps_angle_filter = self.edge_spherical_embd_for_angle.irreps_out
572+
self.angle_edge_filter_out = util.infer_irreps_out(
573+
self.irreps_angle_filter, # type: ignore
574+
self.irreps_angle_filter,
575+
self.angle_lmax, # type: ignore
576+
'full',
577+
False,
578+
)
579+
# edge_i x edge_j --> linear --> angle_filter
580+
self.edge_to_angle_filter_prod = FCTP(self.irreps_angle_filter, self.irreps_angle_filter, self.angle_edge_filter_out)
581+
self.edge_to_angle_filter_linear = Linear(
582+
irreps_in=self.angle_edge_filter_out,
583+
irreps_out=self.irreps_angle_filter,
584+
biases=False,
585+
)
586+
else:
587+
self.edge_spherical_embd_for_angle = None
588+
self.irreps_angle_filter = "0e"
589+
self.edge_to_angle_filter_prod = None
590+
self.edge_to_angle_filter_linear = None
591+
592+
557593
for ii in range(nlayers):
594+
# for node edge e3nn conv
558595
irreps_out = Irreps(self.e3nn_conv_pattern)
559596
irreps_out_tp = util.infer_irreps_out(
560597
irreps_x, # type: ignore
@@ -573,6 +610,27 @@ def __init__(
573610
"weight_layer_input_to_hidden": [8, 64, 64] if not self.e3nn_use_edge_feat_weights else [self.e_dim],
574611
}
575612
irreps_x = irreps_out
613+
614+
# for edge angle e3nn conv
615+
irreps_edge_out = Irreps(self.e3nn_angle_conv_pattern)
616+
irreps_out_tp = util.infer_irreps_out(
617+
irreps_edge, # type: ignore
618+
self.irreps_angle_filter,
619+
irreps_edge_out.lmax, # type: ignore
620+
'full',
621+
False,
622+
)
623+
e3nn_angle_conv_args = {
624+
"irreps_x": irreps_edge,
625+
"irreps_filter": self.irreps_angle_filter,
626+
"irreps_out_tp": irreps_out_tp,
627+
"irreps_out": irreps_edge_out,
628+
"denominator": 1.0 if not self.use_e3nn_denominator else self.dynamic_a_sel / 4,
629+
"train_denominator": True,
630+
"weight_layer_input_to_hidden": [self.a_dim],
631+
}
632+
irreps_edge = irreps_edge_out
633+
576634
layers.append(
577635
RepFlowLayer(
578636
e_rcut=self.e_rcut,
@@ -638,6 +696,9 @@ def __init__(
638696
e3nn_conv_pattern=self.e3nn_conv_pattern,
639697
e3nn_use_edge_feat_weights=self.e3nn_use_edge_feat_weights,
640698
e3nn_conv_args=e3nn_conv_args,
699+
use_e3nn_angle_conv=self.use_e3nn_angle_conv,
700+
e3nn_angle_conv_args=e3nn_angle_conv_args,
701+
e3nn_angle_conv_pattern=self.e3nn_angle_conv_pattern,
641702
seed=child_seed(child_seed(seed, 1), ii),
642703
)
643704
)
@@ -1183,6 +1244,21 @@ def forward(
11831244
edge_sph = None
11841245
node_sph_embed = None
11851246

1247+
if self.use_e3nn_angle_conv:
1248+
assert self.use_dynamic_sel, "e3nn conv must use dynamic sel"
1249+
assert self.edge_spherical_embd_for_angle is not None
1250+
assert self.edge_to_angle_filter_prod is not None
1251+
assert self.edge_to_angle_filter_linear is not None
1252+
edge_sph_for_angle = self.edge_spherical_embd_for_angle(diff)
1253+
edge_i_index = angle_index[:, 1]
1254+
edge_j_index = angle_index[:, 2]
1255+
edge_angle_filter_tp = self.edge_to_angle_filter_prod(edge_sph_for_angle[edge_i_index], edge_sph_for_angle[edge_j_index])
1256+
edge_angle_filter = self.edge_to_angle_filter_linear(edge_angle_filter_tp)
1257+
edge_sph_embed = edge_ebd
1258+
else:
1259+
edge_angle_filter = None
1260+
edge_sph_embed = None
1261+
11861262
for idx, ll in enumerate(self.layers):
11871263
# node_ebd: nb x nloc x n_dim
11881264
# node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parrallel_mode
@@ -1256,7 +1332,7 @@ def forward(
12561332
# for jit
12571333
assert not self.use_rk_update
12581334
assert dihedral_ebd is not None
1259-
node_ebd, edge_ebd, angle_ebd, dihedral_ebd, ___ = ll.forward(
1335+
node_ebd, edge_ebd, angle_ebd, dihedral_ebd, ___, ___, = ll.forward(
12601336
node_ebd_ext,
12611337
edge_ebd,
12621338
h2,
@@ -1279,7 +1355,7 @@ def forward(
12791355
else:
12801356
assert dihedral_ebd is None
12811357
if not self.use_rk_update:
1282-
node_ebd, edge_ebd, angle_ebd, ___, node_sph_embed = ll.forward(
1358+
node_ebd, edge_ebd, angle_ebd, ___, node_sph_embed, edge_sph_embed = ll.forward(
12831359
node_ebd_ext,
12841360
edge_ebd,
12851361
h2,
@@ -1301,6 +1377,8 @@ def forward(
13011377
edge_rbf_ebd=edge_rbf_ebd,
13021378
edge_sph=edge_sph,
13031379
node_sph_embed=node_sph_embed,
1380+
edge_angle_filter=edge_angle_filter,
1381+
edge_sph_embed=edge_sph_embed,
13041382
)
13051383
# may cause jit slow, todo fix
13061384
elif not self.rk_update_diff_layer:
@@ -1321,6 +1399,7 @@ def forward(
13211399
angle_ebd_k1,
13221400
___,
13231401
___,
1402+
___,
13241403
) = ll.forward(
13251404
node_ebd_k_in,
13261405
edge_ebd_k_in,

0 commit comments

Comments
 (0)