Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class RepFlowArgs:
smooth_edge_update : bool, optional
Whether to make edge update smooth.
If True, the edge update from angle message will not use self as padding.
edge_init_use_dist : bool, optional
Whether to use direct distance r to initialize the edge features instead of 1/r.
Note that when using this option, the activation function will not be used when initializing edge features.
use_dynamic_sel : bool, optional
Whether to dynamically select neighbors within the cutoff radius.
If True, the exact number of neighbors within the cutoff radius is used
Expand Down Expand Up @@ -162,6 +165,7 @@ def __init__(
skip_stat: bool = False,
optim_update: bool = True,
smooth_edge_update: bool = False,
edge_init_use_dist: bool = False,
use_dynamic_sel: bool = False,
sel_reduce_factor: float = 10.0,
) -> None:
Expand Down Expand Up @@ -190,6 +194,7 @@ def __init__(
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update
self.edge_init_use_dist = edge_init_use_dist
self.use_dynamic_sel = use_dynamic_sel
self.sel_reduce_factor = sel_reduce_factor

Expand Down Expand Up @@ -223,6 +228,7 @@ def serialize(self) -> dict:
"fix_stat_std": self.fix_stat_std,
"optim_update": self.optim_update,
"smooth_edge_update": self.smooth_edge_update,
"edge_init_use_dist": self.edge_init_use_dist,
"use_dynamic_sel": self.use_dynamic_sel,
"sel_reduce_factor": self.sel_reduce_factor,
}
Expand Down Expand Up @@ -321,6 +327,7 @@ def init_subclass_params(sub_data, sub_class):
fix_stat_std=self.repflow_args.fix_stat_std,
optim_update=self.repflow_args.optim_update,
smooth_edge_update=self.repflow_args.smooth_edge_update,
edge_init_use_dist=self.repflow_args.edge_init_use_dist,
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
exclude_types=exclude_types,
Expand Down
24 changes: 24 additions & 0 deletions deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@
smooth_edge_update : bool, optional
Whether to make edge update smooth.
If True, the edge update from angle message will not use self as padding.
edge_init_use_dist : bool, optional
Whether to use direct distance r to initialize the edge features instead of 1/r.
Note that when using this option, the activation function will not be used when initializing edge features.
use_dynamic_sel : bool, optional
Whether to dynamically select neighbors within the cutoff radius.
If True, the exact number of neighbors within the cutoff radius is used
Expand Down Expand Up @@ -185,6 +188,7 @@
fix_stat_std: float = 0.3,
optim_update: bool = True,
smooth_edge_update: bool = False,
edge_init_use_dist: bool = False,
use_dynamic_sel: bool = False,
sel_reduce_factor: float = 10.0,
seed: Optional[Union[int, list[int]]] = None,
Expand Down Expand Up @@ -218,6 +222,7 @@
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update
self.edge_init_use_dist = edge_init_use_dist
self.use_dynamic_sel = use_dynamic_sel
self.sel_reduce_factor = sel_reduce_factor
if self.use_dynamic_sel and not self.smooth_edge_update:
Expand Down Expand Up @@ -459,6 +464,24 @@
# beyond the cutoff sw should be 0.0
sw = xp.where(nlist_mask, sw, xp.zeros_like(sw))

# nb x nloc x tebd_dim
atype_embd = atype_embd_ext[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]

node_ebd = self.act(atype_embd)
Comment thread Fixed
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
# edge_input, h2 = xp.split(dmatrix, [1], axis=-1)
edge_input = dmatrix[:, :, :, :1]
Comment thread
iProzd marked this conversation as resolved.
Outdated
h2 = dmatrix[:, :, :, 1:]
Comment thread Fixed
if self.edge_init_use_dist:
# nb x nloc x nnei x 1
edge_input = xp.linalg.vector_norm(diff, axis=-1, keepdims=True)
# nb x nloc x nnei x e_dim
edge_ebd = self.edge_embd(edge_input)
Comment thread Fixed
else:
# nb x nloc x nnei x e_dim
edge_ebd = self.act(self.edge_embd(edge_input))
Comment thread Fixed

# get angle nlist (maybe smaller)
a_dist_mask = (xp.linalg.vector_norm(diff, axis=-1) < self.a_rcut)[
:, :, : self.a_sel
Expand Down Expand Up @@ -647,6 +670,7 @@
"precision": self.precision,
"fix_stat_std": self.fix_stat_std,
"optim_update": self.optim_update,
"edge_init_use_dist": self.edge_init_use_dist,
"smooth_edge_update": self.smooth_edge_update,
"use_dynamic_sel": self.use_dynamic_sel,
"sel_reduce_factor": self.sel_reduce_factor,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def init_subclass_params(sub_data, sub_class):
fix_stat_std=self.repflow_args.fix_stat_std,
optim_update=self.repflow_args.optim_update,
smooth_edge_update=self.repflow_args.smooth_edge_update,
edge_init_use_dist=self.repflow_args.edge_init_use_dist,
use_dynamic_sel=self.repflow_args.use_dynamic_sel,
sel_reduce_factor=self.repflow_args.sel_reduce_factor,
exclude_types=exclude_types,
Expand Down
26 changes: 26 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ class DescrptBlockRepflows(DescriptorBlock):
smooth_edge_update : bool, optional
Whether to make edge update smooth.
If True, the edge update from angle message will not use self as padding.
edge_init_use_dist : bool, optional
Whether to use direct distance r to initialize the edge features instead of 1/r.
Note that when using this option, the activation function will not be used when initializing edge features.
use_dynamic_sel : bool, optional
Whether to dynamically select neighbors within the cutoff radius.
If True, the exact number of neighbors within the cutoff radius is used
Expand Down Expand Up @@ -198,6 +201,7 @@ def __init__(
precision: str = "float64",
fix_stat_std: float = 0.3,
smooth_edge_update: bool = False,
edge_init_use_dist: bool = False,
use_dynamic_sel: bool = False,
sel_reduce_factor: float = 10.0,
optim_update: bool = True,
Expand Down Expand Up @@ -232,6 +236,7 @@ def __init__(
self.a_compress_use_split = a_compress_use_split
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update
self.edge_init_use_dist = edge_init_use_dist
self.use_dynamic_sel = use_dynamic_sel
self.sel_reduce_factor = sel_reduce_factor
if self.use_dynamic_sel and not self.smooth_edge_update:
Expand Down Expand Up @@ -431,6 +436,27 @@ def forward(
# beyond the cutoff sw should be 0.0
sw = sw.masked_fill(~nlist_mask, 0.0)

# [nframes, nloc, tebd_dim]
if comm_dict is None:
assert isinstance(extended_atype_embd, torch.Tensor) # for jit
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
else:
atype_embd = extended_atype_embd
assert isinstance(atype_embd, torch.Tensor) # for jit
node_ebd = self.act(atype_embd)
n_dim = node_ebd.shape[-1]
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
if self.edge_init_use_dist:
# nb x nloc x nnei x 1
edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True)
# nb x nloc x nnei x e_dim
edge_ebd = self.edge_embd(edge_input)
else:
# nb x nloc x nnei x e_dim
edge_ebd = self.act(self.edge_embd(edge_input))

# get angle nlist (maybe smaller)
a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[
:, :, : self.a_sel
Expand Down
11 changes: 11 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,6 +1497,9 @@ def dpa3_repflow_args():
"Whether to make edge update smooth. "
"If True, the edge update from angle message will not use self as padding."
)
doc_edge_init_use_dist = (
"Whether to use direct distance r to initialize the edge features instead of 1/r. "
"Note that when using this option, the activation function will not be used when initializing edge features."
doc_use_dynamic_sel = (
"Whether to dynamically select neighbors within the cutoff radius. "
"If True, the exact number of neighbors within the cutoff radius is used "
Expand Down Expand Up @@ -1611,6 +1614,14 @@ def dpa3_repflow_args():
default=False, # For compatability. This will be True in the future
doc=doc_smooth_edge_update,
),
Argument(
"edge_init_use_dist",
bool,
optional=True,
default=False,
alias=["edge_use_dist"],
doc=doc_edge_init_use_dist,
),
Argument(
"use_dynamic_sel",
bool,
Expand Down
15 changes: 14 additions & 1 deletion source/tests/consistent/descriptor/test_dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
(1, 2), # a_compress_e_rate
(True,), # a_compress_use_split
(True, False), # optim_update
(True, False), # edge_init_use_dist
(True, False), # use_dynamic_sel
(0.3, 0.0), # fix_stat_std
(1, 2), # n_multi_edge_message
Expand All @@ -81,6 +82,7 @@
a_compress_e_rate,
a_compress_use_split,
optim_update,
edge_init_use_dist,
use_dynamic_sel,
fix_stat_std,
n_multi_edge_message,
Expand All @@ -105,6 +107,7 @@
"a_compress_e_rate": a_compress_e_rate,
"a_compress_use_split": a_compress_use_split,
"optim_update": optim_update,
"edge_init_use_dist": edge_init_use_dist,
"use_dynamic_sel": use_dynamic_sel,
"smooth_edge_update": True,
"fix_stat_std": fix_stat_std,
Expand Down Expand Up @@ -134,6 +137,7 @@
a_compress_e_rate,
a_compress_use_split,
optim_update,
edge_init_use_dist,
Comment thread Fixed
Comment thread Fixed

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable edge_init_use_dist is not used.
use_dynamic_sel,
fix_stat_std,
n_multi_edge_message,
Expand All @@ -151,13 +155,17 @@
a_compress_e_rate,
a_compress_use_split,
optim_update,
edge_init_use_dist,
use_dynamic_sel,
fix_stat_std,
n_multi_edge_message,
precision,
) = self.param
return (
not INSTALLED_PD or precision == "bfloat16" or use_dynamic_sel
not INSTALLED_PD
or precision == "bfloat16"
or edge_init_use_dist
or use_dynamic_sel
) # not supported yet

@property
Expand All @@ -170,6 +178,7 @@
a_compress_e_rate,
a_compress_use_split,
optim_update,
edge_init_use_dist,
Comment thread Fixed
Comment thread Fixed

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable edge_init_use_dist is not used.
use_dynamic_sel,
fix_stat_std,
n_multi_edge_message,
Expand All @@ -187,6 +196,7 @@
a_compress_e_rate,
a_compress_use_split,
optim_update,
edge_init_use_dist,
Comment thread Fixed
Comment thread Fixed

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable edge_init_use_dist is not used.
use_dynamic_sel,
fix_stat_std,
n_multi_edge_message,
Expand Down Expand Up @@ -246,6 +256,7 @@
a_compress_e_rate,
a_compress_use_split,
optim_update,
edge_init_use_dist,
Comment thread Fixed
Comment thread Fixed

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable edge_init_use_dist is not used.
use_dynamic_sel,
fix_stat_std,
n_multi_edge_message,
Expand Down Expand Up @@ -326,6 +337,7 @@
a_compress_e_rate,
a_compress_use_split,
optim_update,
edge_init_use_dist,
use_dynamic_sel,
fix_stat_std,
n_multi_edge_message,
Expand All @@ -349,6 +361,7 @@
a_compress_e_rate,
a_compress_use_split,
optim_update,
edge_init_use_dist,
use_dynamic_sel,
fix_stat_std,
n_multi_edge_message,
Expand Down
5 changes: 4 additions & 1 deletion source/tests/universal/dpmodel/descriptor/test_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def DescriptorParamDPA3(
a_compress_use_split=False,
optim_update=True,
smooth_edge_update=False,
edge_init_use_dist=False,
fix_stat_std=0.3,
use_dynamic_sel=False,
precision="float64",
Expand All @@ -505,6 +506,7 @@ def DescriptorParamDPA3(
"a_compress_use_split": a_compress_use_split,
"optim_update": optim_update,
"smooth_edge_update": smooth_edge_update,
"edge_init_use_dist": edge_init_use_dist,
"fix_stat_std": fix_stat_std,
"n_multi_edge_message": n_multi_edge_message,
"axis_neuron": 2,
Expand Down Expand Up @@ -543,8 +545,9 @@ def DescriptorParamDPA3(
"a_compress_use_split": (True,),
"optim_update": (True, False),
"smooth_edge_update": (True,),
"edge_init_use_dist": (True, False),
"fix_stat_std": (0.3,),
"n_multi_edge_message": (1, 2),
"n_multi_edge_message": (1,),
"use_dynamic_sel": (True, False),
"env_protection": (0.0, 1e-8),
"precision": ("float64",),
Expand Down
Loading