Skip to content

Commit b854fc2

Browse files
committed
fix(pt): unsmooth update of edge update in DPA3
1 parent 0917d4e commit b854fc2

7 files changed

Lines changed: 133 additions & 18 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def __init__(
5454
auto_batchsize: int = 0,
5555
optim_update: bool = True,
5656
no_sym: bool = False,
57+
smooth_angle_init: bool = False,
58+
angle_init_use_sin: bool = False,
59+
smooth_edge_update: bool = False,
5760
) -> None:
5861
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
5962
@@ -153,6 +156,9 @@ def __init__(
153156
self.auto_batchsize = auto_batchsize
154157
self.optim_update = optim_update
155158
self.no_sym = no_sym
159+
self.smooth_angle_init = smooth_angle_init
160+
self.angle_init_use_sin = angle_init_use_sin
161+
self.smooth_edge_update = smooth_edge_update
156162

157163
def __getitem__(self, key):
158164
if hasattr(self, key):

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,9 @@ def init_subclass_params(sub_data, sub_class):
192192
optim_update=self.repflow_args.optim_update,
193193
skip_stat=self.repflow_args.skip_stat,
194194
no_sym=self.repflow_args.no_sym,
195+
smooth_angle_init=self.repflow_args.smooth_angle_init,
196+
angle_init_use_sin=self.repflow_args.angle_init_use_sin,
197+
smooth_edge_update=self.repflow_args.smooth_edge_update,
195198
exclude_types=exclude_types,
196199
env_protection=env_protection,
197200
precision=precision,

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(
7878
bn_moment: float = 0.1,
7979
optim_update: bool = True,
8080
no_sym: bool = False,
81+
smooth_edge_update: bool = False,
8182
activation_function: str = "silu",
8283
update_style: str = "res_residual",
8384
update_residual: float = 0.1,
@@ -145,6 +146,7 @@ def __init__(
145146
self.n_update_has_a_first_sum = n_update_has_a_first_sum
146147
self.optim_update = optim_update
147148
self.no_sym = no_sym
149+
self.smooth_edge_update = smooth_edge_update
148150

149151
assert update_residual_init in [
150152
"norm",
@@ -1146,20 +1148,22 @@ def forward(
11461148
],
11471149
dim=2,
11481150
)
1149-
full_mask = torch.concat(
1150-
[
1151-
a_nlist_mask,
1152-
torch.zeros(
1153-
[nb, nloc, self.nnei - self.a_sel],
1154-
dtype=a_nlist_mask.dtype,
1155-
device=a_nlist_mask.device,
1156-
),
1157-
],
1158-
dim=-1,
1159-
)
1160-
padding_edge_angle_update = torch.where(
1161-
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
1162-
)
1151+
if not self.smooth_edge_update:
1152+
# will be deprecated in the future
1153+
full_mask = torch.concat(
1154+
[
1155+
a_nlist_mask,
1156+
torch.zeros(
1157+
[nb, nloc, self.nnei - self.a_sel],
1158+
dtype=a_nlist_mask.dtype,
1159+
device=a_nlist_mask.device,
1160+
),
1161+
],
1162+
dim=-1,
1163+
)
1164+
padding_edge_angle_update = torch.where(
1165+
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
1166+
)
11631167
e_update_list.append(
11641168
self.act(self.edge_angle_linear2(padding_edge_angle_update))
11651169
)

deepmd/pt/model/descriptor/repflows.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ def __init__(
112112
precision: str = "float64",
113113
skip_stat: bool = True,
114114
no_sym: bool = False,
115+
smooth_angle_init: bool = False,
116+
angle_init_use_sin: bool = False,
117+
smooth_edge_update: bool = False,
115118
pre_ln: bool = False,
116119
only_e_ln: bool = False,
117120
pre_bn: bool = False,
@@ -239,6 +242,9 @@ def __init__(
239242
self.auto_batchsize = auto_batchsize
240243
self.optim_update = optim_update
241244
self.no_sym = no_sym
245+
self.smooth_angle_init = smooth_angle_init
246+
self.angle_init_use_sin = angle_init_use_sin
247+
self.smooth_edge_update = smooth_edge_update
242248

243249
self.n_dim = n_dim
244250
self.e_dim = e_dim
@@ -299,7 +305,11 @@ def __init__(
299305
1, self.e_dim, precision=precision, seed=child_seed(seed, 0)
300306
)
301307
self.angle_embd = MLPLayer(
302-
1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1)
308+
1 if not self.angle_init_use_sin else 2,
309+
self.a_dim,
310+
precision=precision,
311+
bias=False,
312+
seed=child_seed(seed, 1),
303313
)
304314
self.has_h1 = self.update_n_has_h1 or self.update_e_has_h1
305315
if self.has_h1:
@@ -452,6 +462,7 @@ def __init__(
452462
bn_moment=self.bn_moment,
453463
optim_update=self.optim_update,
454464
no_sym=self.no_sym,
465+
smooth_edge_update=self.smooth_edge_update,
455466
seed=child_seed(child_seed(seed, 1), ii),
456467
)
457468
)
@@ -632,10 +643,21 @@ def forward(
632643
# nf x nloc x a_nnei x a_nnei
633644
# 1 - 1e-6 for torch.acos stability
634645
cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6)
635-
# nf x nloc x a_nnei x a_nnei x 1
636-
cosine_ij = cosine_ij.unsqueeze(-1) / (torch.pi**0.5)
646+
sine_ij = torch.sqrt(1 - cosine_ij**2)
647+
if self.smooth_angle_init:
648+
cosine_ij = cosine_ij * a_sw.unsqueeze(-1) * a_sw.unsqueeze(-2)
649+
sine_ij = sine_ij * a_sw.unsqueeze(-1) * a_sw.unsqueeze(-2)
650+
651+
if not self.angle_init_use_sin:
652+
# nf x nloc x a_nnei x a_nnei x 1
653+
angle_input = cosine_ij.unsqueeze(-1) / (torch.pi**0.5)
654+
else:
655+
angle_input = torch.cat(
656+
[cosine_ij.unsqueeze(-1), sine_ij.unsqueeze(-1)], dim=-1
657+
) / (torch.pi**0.5)
658+
637659
# nf x nloc x a_nnei x a_nnei x a_dim
638-
angle_ebd = self.angle_embd(cosine_ij).reshape(
660+
angle_ebd = self.angle_embd(angle_input).reshape(
639661
nframes, nloc, self.a_sel, self.a_sel, self.a_dim
640662
)
641663
if self.has_h1:

deepmd/utils/argcheck.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1727,6 +1727,24 @@ def dpa3_repflow_args():
17271727
optional=True,
17281728
default=False,
17291729
),
1730+
Argument(
1731+
"smooth_angle_init",
1732+
bool,
1733+
optional=True,
1734+
default=False,
1735+
),
1736+
Argument(
1737+
"angle_init_use_sin",
1738+
bool,
1739+
optional=True,
1740+
default=False,
1741+
),
1742+
Argument(
1743+
"smooth_edge_update",
1744+
bool,
1745+
optional=True,
1746+
default=False,
1747+
),
17301748
]
17311749

17321750

source/tests/pt/model/test_permutation.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,46 @@
167167
},
168168
}
169169

170+
model_dpa3 = {
171+
"type_map": ["O", "H", "B"],
172+
"descriptor": {
173+
"type": "dpa3",
174+
"repflow": {
175+
"n_dim": 30,
176+
"e_dim": 20,
177+
"a_dim": 10,
178+
"nlayers": 3,
179+
"e_rcut": 6.0,
180+
"e_rcut_smth": 3.5,
181+
"e_sel": 20,
182+
"a_rcut": 4.0,
183+
"a_rcut_smth": 3.5,
184+
"a_sel": 12,
185+
"axis_neuron": 4,
186+
"skip_stat": True,
187+
"a_compress_rate": 1,
188+
"a_compress_e_rate": 2,
189+
"a_compress_use_split": True,
190+
"update_angle": True,
191+
"smooth_edge_update": True,
192+
"update_style": "res_residual",
193+
"update_residual": 0.1,
194+
"update_residual_init": "const",
195+
},
196+
"activation_function": "custom_silu:10.0",
197+
"use_tebd_bias": False,
198+
"precision": "float32",
199+
"concat_output_tebd": False,
200+
},
201+
"fitting_net": {
202+
"neuron": [24, 24],
203+
"activation_function": "custom_silu:10.0",
204+
"resnet_dt": True,
205+
"precision": "float32",
206+
"seed": 1,
207+
},
208+
}
209+
170210
model_dpa2tebd = {
171211
"type_map": ["O", "H", "B"],
172212
"descriptor": {

source/tests/pt/model/test_smooth.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
model_dos,
2222
model_dpa1,
2323
model_dpa2,
24+
model_dpa3,
2425
model_hybrid,
2526
model_se_e2_a,
2627
model_spin,
@@ -53,6 +54,12 @@ def test(
5354
0.0,
5455
0.0,
5556
0.0,
57+
6.0 - 0.5 * epsilon,
58+
0.0,
59+
0.0,
60+
0.0,
61+
6.0 - 0.5 * epsilon,
62+
0.0,
5663
4.0 - 0.5 * epsilon,
5764
0.0,
5865
0.0,
@@ -77,11 +84,15 @@ def test(
7784
coord0 = torch.clone(coord)
7885
coord1 = torch.clone(coord)
7986
coord1[1][0] += epsilon
87+
coord1[3][0] += epsilon
8088
coord2 = torch.clone(coord)
8189
coord2[2][1] += epsilon
90+
coord2[4][1] += epsilon
8291
coord3 = torch.clone(coord)
8392
coord3[1][0] += epsilon
93+
coord1[3][0] += epsilon
8494
coord3[2][1] += epsilon
95+
coord2[4][1] += epsilon
8596
test_spin = getattr(self, "test_spin", False)
8697
if not test_spin:
8798
test_keys = ["energy", "force", "virial"]
@@ -226,6 +237,17 @@ def setUp(self) -> None:
226237
self.epsilon, self.aprec = None, None
227238

228239

240+
class TestEnergyModelDPA3(unittest.TestCase, SmoothTest):
241+
def setUp(self) -> None:
242+
model_params = copy.deepcopy(model_dpa3)
243+
self.type_split = True
244+
self.model = get_model(model_params).to(env.DEVICE)
245+
# less degree of smoothness,
246+
# error can be systematically removed by reducing epsilon
247+
self.epsilon = 1e-5
248+
self.aprec = 1e-5
249+
250+
229251
class TestEnergyModelHybrid(unittest.TestCase, SmoothTest):
230252
def setUp(self) -> None:
231253
model_params = copy.deepcopy(model_hybrid)

0 commit comments

Comments
 (0)