Skip to content

Commit e9e39ad

Browse files
committed
feat add dist edge
1 parent 3587d07 commit e9e39ad

4 files changed

Lines changed: 20 additions & 1 deletion

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def __init__(
151151
skip_stat: bool = False,
152152
optim_update: bool = True,
153153
smooth_edge_update: bool = False,
154+
edge_init_use_dist: bool = False,
154155
use_exp_switch: bool = False,
155156
use_dynamic_sel: bool = False,
156157
sel_reduce_factor: float = 10.0,
@@ -180,6 +181,7 @@ def __init__(
180181
self.a_compress_use_split = a_compress_use_split
181182
self.optim_update = optim_update
182183
self.smooth_edge_update = smooth_edge_update
184+
self.edge_init_use_dist = edge_init_use_dist
183185
self.use_exp_switch = use_exp_switch
184186
self.use_dynamic_sel = use_dynamic_sel
185187
self.sel_reduce_factor = sel_reduce_factor

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def init_subclass_params(sub_data, sub_class):
150150
fix_stat_std=self.repflow_args.fix_stat_std,
151151
optim_update=self.repflow_args.optim_update,
152152
smooth_edge_update=self.repflow_args.smooth_edge_update,
153+
edge_init_use_dist=self.repflow_args.edge_init_use_dist,
153154
use_exp_switch=self.repflow_args.use_exp_switch,
154155
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
155156
sel_reduce_factor=self.repflow_args.sel_reduce_factor,

deepmd/pt/model/descriptor/repflows.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def __init__(
186186
precision: str = "float64",
187187
fix_stat_std: float = 0.3,
188188
smooth_edge_update: bool = False,
189+
edge_init_use_dist: bool = False,
189190
use_exp_switch: bool = False,
190191
use_dynamic_sel: bool = False,
191192
sel_reduce_factor: float = 10.0,
@@ -221,6 +222,7 @@ def __init__(
221222
self.a_compress_use_split = a_compress_use_split
222223
self.optim_update = optim_update
223224
self.smooth_edge_update = smooth_edge_update
225+
self.edge_init_use_dist = edge_init_use_dist
224226
self.use_exp_switch = use_exp_switch
225227
self.use_dynamic_sel = use_dynamic_sel
226228
self.sel_reduce_factor = sel_reduce_factor
@@ -450,6 +452,10 @@ def forward(
450452
# get edge and angle embedding input
451453
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
452454
edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
455+
if self.edge_init_use_dist:
456+
# nb x nloc x nnei x 1
457+
edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True)
458+
453459
# nf x nloc x a_nnei x 3
454460
normalized_diff_i = a_diff / (
455461
torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6
@@ -486,7 +492,10 @@ def forward(
486492
)
487493
# get edge and angle embedding
488494
# nb x nloc x nnei x e_dim [OR] n_edge x e_dim
489-
edge_ebd = self.act(self.edge_embd(edge_input))
495+
if not self.edge_init_use_dist:
496+
edge_ebd = self.act(self.edge_embd(edge_input))
497+
else:
498+
edge_ebd = self.edge_embd(edge_input)
490499
# nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
491500
angle_ebd = self.angle_embd(angle_input)
492501

deepmd/utils/argcheck.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,6 +1597,13 @@ def dpa3_repflow_args():
15971597
default=False, # For compatability. This will be True in the future
15981598
doc=doc_smooth_edge_update,
15991599
),
1600+
Argument(
1601+
"edge_init_use_dist",
1602+
bool,
1603+
optional=True,
1604+
default=False,
1605+
alias=["edge_use_dist"],
1606+
),
16001607
Argument(
16011608
"use_exp_switch",
16021609
bool,

0 commit comments

Comments
 (0)