diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index b2f27195a8..a258d1892b 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -123,6 +123,9 @@ class RepFlowArgs: smooth_edge_update : bool, optional Whether to make edge update smooth. If True, the edge update from angle message will not use self as padding. + edge_init_use_dist : bool, optional + Whether to use direct distance r to initialize the edge features instead of 1/r. + Note that when using this option, the activation function will not be used when initializing edge features. use_exp_switch : bool, optional Whether to use an exponential switch function instead of a polynomial one in the neighbor update. The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance @@ -170,6 +173,7 @@ def __init__( skip_stat: bool = False, optim_update: bool = True, smooth_edge_update: bool = False, + edge_init_use_dist: bool = False, use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, @@ -199,6 +203,7 @@ def __init__( self.a_compress_use_split = a_compress_use_split self.optim_update = optim_update self.smooth_edge_update = smooth_edge_update + self.edge_init_use_dist = edge_init_use_dist self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor @@ -233,6 +238,7 @@ def serialize(self) -> dict: "fix_stat_std": self.fix_stat_std, "optim_update": self.optim_update, "smooth_edge_update": self.smooth_edge_update, + "edge_init_use_dist": self.edge_init_use_dist, "use_exp_switch": self.use_exp_switch, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, @@ -332,6 +338,7 @@ def init_subclass_params(sub_data, sub_class): fix_stat_std=self.repflow_args.fix_stat_std, optim_update=self.repflow_args.optim_update, smooth_edge_update=self.repflow_args.smooth_edge_update, + edge_init_use_dist=self.repflow_args.edge_init_use_dist, use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index df0b81d9d2..926b500645 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -125,6 +125,9 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock): smooth_edge_update : bool, optional Whether to make edge update smooth. If True, the edge update from angle message will not use self as padding. + edge_init_use_dist : bool, optional + Whether to use direct distance r to initialize the edge features instead of 1/r. + Note that when using this option, the activation function will not be used when initializing edge features. use_exp_switch : bool, optional Whether to use an exponential switch function instead of a polynomial one in the neighbor update. The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance @@ -193,6 +196,7 @@ def __init__( fix_stat_std: float = 0.3, optim_update: bool = True, smooth_edge_update: bool = False, + edge_init_use_dist: bool = False, use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, @@ -227,6 +231,7 @@ def __init__( self.a_compress_use_split = a_compress_use_split self.optim_update = optim_update self.smooth_edge_update = smooth_edge_update + self.edge_init_use_dist = edge_init_use_dist self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor @@ -510,7 +515,11 @@ def call( # get edge and angle embedding input # nb x nloc x nnei x 1, nb x nloc x nnei x 3 # edge_input, h2 = xp.split(dmatrix, [1], axis=-1) - edge_input = dmatrix[:, :, :, :1] + # nb x nloc x nnei x 1 + if self.edge_init_use_dist: + edge_input = xp.linalg.vector_norm(diff, axis=-1, keepdims=True) + else: + edge_input = dmatrix[:, :, :, :1] h2 = dmatrix[:, :, :, 1:] # nf x nloc x a_nnei x 3 @@ -552,7 +561,10 @@ def call( # get edge and angle embedding # nb x nloc x nnei x e_dim [OR] n_edge x e_dim - edge_ebd = self.act(self.edge_embd(edge_input)) + if not self.edge_init_use_dist: + edge_ebd = self.act(self.edge_embd(edge_input)) + else: + edge_ebd = self.edge_embd(edge_input) # nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim angle_ebd = self.angle_embd(angle_input) @@ -663,6 +675,7 @@ def serialize(self): "precision": self.precision, "fix_stat_std": self.fix_stat_std, "optim_update": self.optim_update, + "edge_init_use_dist": self.edge_init_use_dist, "use_exp_switch": self.use_exp_switch, "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index de7b25749d..16e9022baf 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -150,6 +150,7 @@ def init_subclass_params(sub_data, sub_class): fix_stat_std=self.repflow_args.fix_stat_std, optim_update=self.repflow_args.optim_update, smooth_edge_update=self.repflow_args.smooth_edge_update, + edge_init_use_dist=self.repflow_args.edge_init_use_dist, use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 1486ee358a..67d3642771 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -136,6 +136,9 @@ class DescrptBlockRepflows(DescriptorBlock): smooth_edge_update : bool, optional Whether to make edge update smooth. If True, the edge update from angle message will not use self as padding. + edge_init_use_dist : bool, optional + Whether to use direct distance r to initialize the edge features instead of 1/r. + Note that when using this option, the activation function will not be used when initializing edge features. use_exp_switch : bool, optional Whether to use an exponential switch function instead of a polynomial one in the neighbor update. The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance @@ -206,6 +209,7 @@ def __init__( precision: str = "float64", fix_stat_std: float = 0.3, smooth_edge_update: bool = False, + edge_init_use_dist: bool = False, use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, @@ -241,6 +245,7 @@ def __init__( self.a_compress_use_split = a_compress_use_split self.optim_update = optim_update self.smooth_edge_update = smooth_edge_update + self.edge_init_use_dist = edge_init_use_dist self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor @@ -483,6 +488,10 @@ def forward( # get edge and angle embedding input # nb x nloc x nnei x 1, nb x nloc x nnei x 3 edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1) + if self.edge_init_use_dist: + # nb x nloc x nnei x 1 + edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True) + # nf x nloc x a_nnei x 3 normalized_diff_i = a_diff / ( torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6 @@ -519,7 +528,10 @@ def forward( ) # get edge and angle embedding # nb x nloc x nnei x e_dim [OR] n_edge x e_dim - edge_ebd = self.act(self.edge_embd(edge_input)) + if not self.edge_init_use_dist: + edge_ebd = self.act(self.edge_embd(edge_input)) + else: + edge_ebd = self.edge_embd(edge_input) # nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim angle_ebd = self.angle_embd(angle_input) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 6e9663592f..17708a19ff 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1497,6 +1497,10 @@ def dpa3_repflow_args(): "Whether to make edge update smooth. " "If True, the edge update from angle message will not use self as padding." ) + doc_edge_init_use_dist = ( + "Whether to use direct distance r to initialize the edge features instead of 1/r. " + "Note that when using this option, the activation function will not be used when initializing edge features." + ) doc_use_exp_switch = ( "Whether to use an exponential switch function instead of a polynomial one in the neighbor update. " "The exponential switch function ensures neighbor contributions smoothly diminish as the interatomic distance " @@ -1620,6 +1624,14 @@ def dpa3_repflow_args(): default=False, # For compatability. This will be True in the future doc=doc_smooth_edge_update, ), + Argument( + "edge_init_use_dist", + bool, + optional=True, + default=False, + alias=["edge_use_dist"], + doc=doc_edge_init_use_dist, + ), Argument( "use_exp_switch", bool, diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index 2647da52b3..43059f54ba 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -65,6 +65,7 @@ (1, 2), # a_compress_e_rate (True,), # a_compress_use_split (True, False), # optim_update + (True, False), # edge_init_use_dist (True, False), # use_exp_switch (True, False), # use_dynamic_sel (0.3, 0.0), # fix_stat_std @@ -82,6 +83,7 @@ def data(self) -> dict: a_compress_e_rate, a_compress_use_split, optim_update, + edge_init_use_dist, use_exp_switch, use_dynamic_sel, fix_stat_std, @@ -107,6 +109,7 @@ def data(self) -> dict: "a_compress_e_rate": a_compress_e_rate, "a_compress_use_split": a_compress_use_split, "optim_update": optim_update, + "edge_init_use_dist": edge_init_use_dist, "use_exp_switch": use_exp_switch, "use_dynamic_sel": use_dynamic_sel, "smooth_edge_update": True, @@ -137,6 +140,7 @@ def skip_pt(self) -> bool: a_compress_e_rate, a_compress_use_split, optim_update, + edge_init_use_dist, use_exp_switch, use_dynamic_sel, fix_stat_std, @@ -155,6 +159,7 @@ def skip_pd(self) -> bool: a_compress_e_rate, a_compress_use_split, optim_update, + edge_init_use_dist, use_exp_switch, use_dynamic_sel, fix_stat_std, @@ -164,6 +169,7 @@ def skip_pd(self) -> bool: return ( not INSTALLED_PD or precision == "bfloat16" + or edge_init_use_dist or use_exp_switch or use_dynamic_sel ) # not supported yet @@ -178,6 +184,7 @@ def skip_dp(self) -> bool: a_compress_e_rate, a_compress_use_split, optim_update, + edge_init_use_dist, use_exp_switch, use_dynamic_sel, fix_stat_std, @@ -196,6 +203,7 @@ def skip_tf(self) -> bool: a_compress_e_rate, a_compress_use_split, optim_update, + edge_init_use_dist, use_exp_switch, use_dynamic_sel, fix_stat_std, @@ -256,6 +264,7 @@ def setUp(self) -> None: a_compress_e_rate, a_compress_use_split, optim_update, + edge_init_use_dist, use_exp_switch, use_dynamic_sel, fix_stat_std, @@ -337,6 +346,7 @@ def rtol(self) -> float: a_compress_e_rate, a_compress_use_split, optim_update, + edge_init_use_dist, use_exp_switch, use_dynamic_sel, fix_stat_std, @@ -361,6 +371,7 @@ def atol(self) -> float: a_compress_e_rate, a_compress_use_split, optim_update, + edge_init_use_dist, use_exp_switch, use_dynamic_sel, fix_stat_std, diff --git a/source/tests/universal/dpmodel/descriptor/test_descriptor.py b/source/tests/universal/dpmodel/descriptor/test_descriptor.py index 08708c5924..499729fff2 100644 --- a/source/tests/universal/dpmodel/descriptor/test_descriptor.py +++ b/source/tests/universal/dpmodel/descriptor/test_descriptor.py @@ -482,6 +482,7 @@ def DescriptorParamDPA3( a_compress_use_split=False, optim_update=True, smooth_edge_update=False, + edge_init_use_dist=False, use_exp_switch=False, fix_stat_std=0.3, use_dynamic_sel=False, @@ -511,6 +512,7 @@ def DescriptorParamDPA3( "optim_update": optim_update, "use_exp_switch": use_exp_switch, "smooth_edge_update": smooth_edge_update, + "edge_init_use_dist": edge_init_use_dist, "fix_stat_std": fix_stat_std, "n_multi_edge_message": n_multi_edge_message, "axis_neuron": 2, @@ -549,9 +551,10 @@ def DescriptorParamDPA3( "a_compress_use_split": (True,), "optim_update": (True, False), "smooth_edge_update": (True,), + "edge_init_use_dist": (True, False), "use_exp_switch": (True, False), "fix_stat_std": (0.3,), - "n_multi_edge_message": (1, 2), + "n_multi_edge_message": (1,), "use_dynamic_sel": (True, False), "env_protection": (0.0, 1e-8), "precision": ("float64",),