Skip to content

Commit f4fb536

Browse files
committed
make angle update smooth
1 parent 250ae65 commit f4fb536

7 files changed

Lines changed: 133 additions & 18 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ def __init__(
2525
update_residual_init: str = "const",
2626
skip_stat: bool = False,
2727
optim_update: bool = True,
28+
smooth_angle_init: bool = False,
29+
angle_init_use_sin: bool = False,
30+
smooth_edge_update: bool = False,
2831
) -> None:
2932
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
3033
@@ -102,6 +105,9 @@ def __init__(
102105
self.a_compress_e_rate = a_compress_e_rate
103106
self.a_compress_use_split = a_compress_use_split
104107
self.optim_update = optim_update
108+
self.smooth_angle_init = smooth_angle_init
109+
self.angle_init_use_sin = angle_init_use_sin
110+
self.smooth_edge_update = smooth_edge_update
105111

106112
def __getitem__(self, key):
107113
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ def init_subclass_params(sub_data, sub_class):
163163
update_residual_init=self.repflow_args.update_residual_init,
164164
optim_update=self.repflow_args.optim_update,
165165
skip_stat=self.repflow_args.skip_stat,
166+
smooth_angle_init=self.repflow_args.smooth_angle_init,
167+
angle_init_use_sin=self.repflow_args.angle_init_use_sin,
168+
smooth_edge_update=self.repflow_args.smooth_edge_update,
166169
exclude_types=exclude_types,
167170
env_protection=env_protection,
168171
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(
5252
axis_neuron: int = 4,
5353
update_angle: bool = True, # angle
5454
optim_update: bool = True,
55+
smooth_edge_update: bool = False,
5556
activation_function: str = "silu",
5657
update_style: str = "res_residual",
5758
update_residual: float = 0.1,
@@ -96,6 +97,7 @@ def __init__(
9697
self.seed = seed
9798
self.prec = PRECISION_DICT[precision]
9899
self.optim_update = optim_update
100+
self.smooth_edge_update = smooth_edge_update
99101

100102
assert update_residual_init in [
101103
"norm",
@@ -718,20 +720,22 @@ def forward(
718720
],
719721
dim=2,
720722
)
721-
full_mask = torch.concat(
722-
[
723-
a_nlist_mask,
724-
torch.zeros(
725-
[nb, nloc, self.nnei - self.a_sel],
726-
dtype=a_nlist_mask.dtype,
727-
device=a_nlist_mask.device,
728-
),
729-
],
730-
dim=-1,
731-
)
732-
padding_edge_angle_update = torch.where(
733-
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
734-
)
723+
if not self.smooth_edge_update:
724+
# will be deprecated in the future
725+
full_mask = torch.concat(
726+
[
727+
a_nlist_mask,
728+
torch.zeros(
729+
[nb, nloc, self.nnei - self.a_sel],
730+
dtype=a_nlist_mask.dtype,
731+
device=a_nlist_mask.device,
732+
),
733+
],
734+
dim=-1,
735+
)
736+
padding_edge_angle_update = torch.where(
737+
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
738+
)
735739
e_update_list.append(
736740
self.act(self.edge_angle_linear2(padding_edge_angle_update))
737741
)

deepmd/pt/model/descriptor/repflows.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def __init__(
100100
env_protection: float = 0.0,
101101
precision: str = "float64",
102102
skip_stat: bool = True,
103+
smooth_angle_init: bool = False,
104+
angle_init_use_sin: bool = False,
105+
smooth_edge_update: bool = False,
103106
optim_update: bool = True,
104107
seed: Optional[Union[int, list[int]]] = None,
105108
) -> None:
@@ -202,6 +205,9 @@ def __init__(
202205
self.skip_stat = skip_stat
203206
self.a_compress_use_split = a_compress_use_split
204207
self.optim_update = optim_update
208+
self.smooth_angle_init = smooth_angle_init
209+
self.angle_init_use_sin = angle_init_use_sin
210+
self.smooth_edge_update = smooth_edge_update
205211

206212
self.n_dim = n_dim
207213
self.e_dim = e_dim
@@ -226,7 +232,11 @@ def __init__(
226232
1, self.e_dim, precision=precision, seed=child_seed(seed, 0)
227233
)
228234
self.angle_embd = MLPLayer(
229-
1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1)
235+
1 if not self.angle_init_use_sin else 2,
236+
self.a_dim,
237+
precision=precision,
238+
bias=False,
239+
seed=child_seed(seed, 1),
230240
)
231241
layers = []
232242
for ii in range(nlayers):
@@ -254,6 +264,7 @@ def __init__(
254264
update_residual_init=self.update_residual_init,
255265
precision=precision,
256266
optim_update=self.optim_update,
267+
smooth_edge_update=self.smooth_edge_update,
257268
seed=child_seed(child_seed(seed, 1), ii),
258269
)
259270
)
@@ -434,10 +445,21 @@ def forward(
434445
# nf x nloc x a_nnei x a_nnei
435446
# 1 - 1e-6 for torch.acos stability
436447
cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6)
437-
# nf x nloc x a_nnei x a_nnei x 1
438-
cosine_ij = cosine_ij.unsqueeze(-1) / (torch.pi**0.5)
448+
sine_ij = torch.sqrt(1 - cosine_ij**2)
449+
if self.smooth_angle_init:
450+
cosine_ij = cosine_ij * a_sw.unsqueeze(-1) * a_sw.unsqueeze(-2)
451+
sine_ij = sine_ij * a_sw.unsqueeze(-1) * a_sw.unsqueeze(-2)
452+
453+
if not self.angle_init_use_sin:
454+
# nf x nloc x a_nnei x a_nnei x 1
455+
angle_input = cosine_ij.unsqueeze(-1) / (torch.pi**0.5)
456+
else:
457+
angle_input = torch.cat(
458+
[cosine_ij.unsqueeze(-1), sine_ij.unsqueeze(-1)], dim=-1
459+
) / (torch.pi**0.5)
460+
439461
# nf x nloc x a_nnei x a_nnei x a_dim
440-
angle_ebd = self.angle_embd(cosine_ij).reshape(
462+
angle_ebd = self.angle_embd(angle_input).reshape(
441463
nframes, nloc, self.a_sel, self.a_sel, self.a_dim
442464
)
443465

deepmd/utils/argcheck.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,24 @@ def dpa3_repflow_args():
15681568
optional=True,
15691569
default=True,
15701570
),
1571+
Argument(
1572+
"smooth_angle_init",
1573+
bool,
1574+
optional=True,
1575+
default=False,
1576+
),
1577+
Argument(
1578+
"angle_init_use_sin",
1579+
bool,
1580+
optional=True,
1581+
default=False,
1582+
),
1583+
Argument(
1584+
"smooth_edge_update",
1585+
bool,
1586+
optional=True,
1587+
default=False,
1588+
),
15711589
]
15721590

15731591

source/tests/pt/model/test_permutation.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,46 @@
167167
},
168168
}
169169

170+
model_dpa3 = {
171+
"type_map": ["O", "H", "B"],
172+
"descriptor": {
173+
"type": "dpa3",
174+
"repflow": {
175+
"n_dim": 30,
176+
"e_dim": 20,
177+
"a_dim": 10,
178+
"nlayers": 3,
179+
"e_rcut": 6.0,
180+
"e_rcut_smth": 3.5,
181+
"e_sel": 20,
182+
"a_rcut": 4.0,
183+
"a_rcut_smth": 3.5,
184+
"a_sel": 12,
185+
"axis_neuron": 4,
186+
"skip_stat": True,
187+
"a_compress_rate": 1,
188+
"a_compress_e_rate": 2,
189+
"a_compress_use_split": True,
190+
"update_angle": True,
191+
"smooth_edge_update": True,
192+
"update_style": "res_residual",
193+
"update_residual": 0.1,
194+
"update_residual_init": "const",
195+
},
196+
"activation_function": "custom_silu:10.0",
197+
"use_tebd_bias": False,
198+
"precision": "float32",
199+
"concat_output_tebd": False,
200+
},
201+
"fitting_net": {
202+
"neuron": [24, 24],
203+
"activation_function": "custom_silu:10.0",
204+
"resnet_dt": True,
205+
"precision": "float32",
206+
"seed": 1,
207+
},
208+
}
209+
170210
model_dpa2tebd = {
171211
"type_map": ["O", "H", "B"],
172212
"descriptor": {

source/tests/pt/model/test_smooth.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
model_dos,
2222
model_dpa1,
2323
model_dpa2,
24+
model_dpa3,
2425
model_hybrid,
2526
model_se_e2_a,
2627
model_spin,
@@ -53,6 +54,12 @@ def test(
5354
0.0,
5455
0.0,
5556
0.0,
57+
6.0 - 0.5 * epsilon,
58+
0.0,
59+
0.0,
60+
0.0,
61+
6.0 - 0.5 * epsilon,
62+
0.0,
5663
4.0 - 0.5 * epsilon,
5764
0.0,
5865
0.0,
@@ -77,11 +84,15 @@ def test(
7784
coord0 = torch.clone(coord)
7885
coord1 = torch.clone(coord)
7986
coord1[1][0] += epsilon
87+
coord1[3][0] += epsilon
8088
coord2 = torch.clone(coord)
8189
coord2[2][1] += epsilon
90+
coord2[4][1] += epsilon
8291
coord3 = torch.clone(coord)
8392
coord3[1][0] += epsilon
93+
coord1[3][0] += epsilon
8494
coord3[2][1] += epsilon
95+
coord2[4][1] += epsilon
8596
test_spin = getattr(self, "test_spin", False)
8697
if not test_spin:
8798
test_keys = ["energy", "force", "virial"]
@@ -226,6 +237,17 @@ def setUp(self) -> None:
226237
self.epsilon, self.aprec = None, None
227238

228239

240+
class TestEnergyModelDPA3(unittest.TestCase, SmoothTest):
241+
def setUp(self) -> None:
242+
model_params = copy.deepcopy(model_dpa3)
243+
self.type_split = True
244+
self.model = get_model(model_params).to(env.DEVICE)
245+
# less degree of smoothness,
246+
# error can be systematically removed by reducing epsilon
247+
self.epsilon = 1e-5
248+
self.aprec = 1e-5
249+
250+
229251
class TestEnergyModelHybrid(unittest.TestCase, SmoothTest):
230252
def setUp(self) -> None:
231253
model_params = copy.deepcopy(model_hybrid)

0 commit comments

Comments
 (0)