Skip to content

Commit 82d9b65

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ac0ebb9 commit 82d9b65

File tree

2 files changed

+93
-30
lines changed

2 files changed

+93
-30
lines changed

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 93 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,12 @@ def __init__(
210210
self.n_a_compress_dim = n_dim
211211
else:
212212
# angle + a_dim/c + a_dim/2c * 2 * e_rate
213-
self.angle_dim += (1 + self.a_compress_e_rate) * (self.a_dim // self.a_compress_rate)
214-
self.e_a_compress_dim = self.a_dim // (2 * self.a_compress_rate) * self.a_compress_e_rate
213+
self.angle_dim += (1 + self.a_compress_e_rate) * (
214+
self.a_dim // self.a_compress_rate
215+
)
216+
self.e_a_compress_dim = (
217+
self.a_dim // (2 * self.a_compress_rate) * self.a_compress_e_rate
218+
)
215219
self.n_a_compress_dim = self.a_dim // self.a_compress_rate
216220
if not self.a_compress_use_split:
217221
self.a_compress_n_linear = MLPLayer(
@@ -327,7 +331,10 @@ def _cal_hg(
327331
# nb x nloc x nnei x e_dim
328332
edge_ebd = _apply_nlist_mask(edge_ebd, nlist_mask)
329333
edge_ebd = _apply_switch(edge_ebd, sw)
330-
invnnei = torch.rsqrt(float(nnei) * torch.ones((nb, nloc, 1, 1), dtype=edge_ebd.dtype, device=edge_ebd.device))
334+
invnnei = torch.rsqrt(
335+
float(nnei)
336+
* torch.ones((nb, nloc, 1, 1), dtype=edge_ebd.dtype, device=edge_ebd.device)
337+
)
331338
# nb x nloc x 3 x e_dim
332339
h2g2 = torch.matmul(torch.transpose(h2, -1, -2), edge_ebd) * invnnei
333340
return h2g2
@@ -375,9 +382,16 @@ def _cal_hg_dynamic(
375382
# n_edge x e_dim
376383
flat_edge_ebd = flat_edge_ebd * flat_sw.unsqueeze(-1)
377384
# n_edge x 3 x e_dim
378-
flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(-1, 3 * e_dim)
385+
flat_h2g2 = (flat_h2.unsqueeze(-1) * flat_edge_ebd.unsqueeze(-2)).reshape(
386+
-1, 3 * e_dim
387+
)
379388
# nf x nloc x 3 x e_dim
380-
h2g2 = aggregate(flat_h2g2, owner, average=False, num_owner=num_owner).reshape(nb, nloc, 3, e_dim) * scale_factor
389+
h2g2 = (
390+
aggregate(flat_h2g2, owner, average=False, num_owner=num_owner).reshape(
391+
nb, nloc, 3, e_dim
392+
)
393+
* scale_factor
394+
)
381395
return h2g2
382396

383397
@staticmethod
@@ -530,7 +544,9 @@ def optim_angle_update(
530544
node_dim = node_ebd.shape[-1]
531545
edge_dim = edge_ebd.shape[-1]
532546
# angle_dim, node_dim, edge_dim, edge_dim
533-
sub_angle, sub_node, sub_edge_ik, sub_edge_ij = torch.split(matrix, [angle_dim, node_dim, edge_dim, edge_dim])
547+
sub_angle, sub_node, sub_edge_ik, sub_edge_ij = torch.split(
548+
matrix, [angle_dim, node_dim, edge_dim, edge_dim]
549+
)
534550

535551
# nf * nloc * a_sel * a_sel * angle_dim
536552
sub_angle_update = torch.matmul(angle_ebd, sub_angle)
@@ -541,7 +557,11 @@ def optim_angle_update(
541557
sub_edge_update_ij = torch.matmul(edge_ebd, sub_edge_ij)
542558

543559
result_update = (
544-
bias + sub_node_update.unsqueeze(2).unsqueeze(3) + sub_edge_update_ik.unsqueeze(2) + sub_edge_update_ij.unsqueeze(3) + sub_angle_update
560+
bias
561+
+ sub_node_update.unsqueeze(2).unsqueeze(3)
562+
+ sub_edge_update_ik.unsqueeze(2)
563+
+ sub_edge_update_ij.unsqueeze(3)
564+
+ sub_angle_update
545565
)
546566
return result_update
547567

@@ -565,15 +585,19 @@ def optim_angle_update_dynamic(
565585
edge_dim = flat_edge_ebd.shape[-1]
566586
angle_dim = flat_angle_ebd.shape[-1]
567587
# angle_dim, node_dim, edge_dim, edge_dim
568-
sub_angle, sub_node, sub_edge_ik, sub_edge_ij = torch.split(matrix, [angle_dim, node_dim, edge_dim, edge_dim])
588+
sub_angle, sub_node, sub_edge_ik, sub_edge_ij = torch.split(
589+
matrix, [angle_dim, node_dim, edge_dim, edge_dim]
590+
)
569591

570592
# n_angle * angle_dim
571593
sub_angle_update = torch.matmul(flat_angle_ebd, sub_angle)
572594

573595
# nf * nloc * angle_dim
574596
sub_node_update = torch.matmul(node_ebd, sub_node)
575597
# n_angle * angle_dim
576-
sub_node_update = torch.index_select(sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), 0, n2a_index)
598+
sub_node_update = torch.index_select(
599+
sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), 0, n2a_index
600+
)
577601

578602
# n_edge * angle_dim
579603
sub_edge_update_ik = torch.matmul(flat_edge_ebd, sub_edge_ik)
@@ -582,7 +606,13 @@ def optim_angle_update_dynamic(
582606
sub_edge_update_ik = torch.index_select(sub_edge_update_ik, 0, eik2a_index)
583607
sub_edge_update_ij = torch.index_select(sub_edge_update_ij, 0, eij2a_index)
584608

585-
result_update = bias + sub_node_update + sub_edge_update_ik + sub_edge_update_ij + sub_angle_update
609+
result_update = (
610+
bias
611+
+ sub_node_update
612+
+ sub_edge_update_ik
613+
+ sub_edge_update_ij
614+
+ sub_angle_update
615+
)
586616
return result_update
587617

588618
def optim_edge_update(
@@ -615,7 +645,9 @@ def optim_edge_update(
615645
# nf * nloc * nnei * node/edge_dim
616646
sub_edge_update = torch.matmul(edge_ebd, edge)
617647

618-
result_update = bias + sub_node_update.unsqueeze(2) + sub_edge_update + sub_node_ext_update
648+
result_update = (
649+
bias + sub_node_update.unsqueeze(2) + sub_edge_update + sub_node_ext_update
650+
)
619651
return result_update
620652

621653
def optim_edge_update_dynamic(
@@ -643,7 +675,9 @@ def optim_edge_update_dynamic(
643675
# nf * nloc * node/edge_dim
644676
sub_node_update = torch.matmul(node_ebd, node)
645677
# n_edge * node/edge_dim
646-
sub_node_update = torch.index_select(sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), 0, n2e_index)
678+
sub_node_update = torch.index_select(
679+
sub_node_update.reshape(nf * nloc, sub_node_update.shape[-1]), 0, n2e_index
680+
)
647681

648682
# nf * nall * node/edge_dim
649683
sub_node_ext_update = torch.matmul(node_ebd_ext, node_ext)
@@ -742,7 +776,9 @@ def forward(
742776
nei_node_ebd = (
743777
_make_nei_g1(node_ebd_ext, nlist)
744778
if not self.use_dynamic_sel
745-
else torch.index_select(node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index)
779+
else torch.index_select(
780+
node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index
781+
)
746782
)
747783

748784
n_update_list: list[torch.Tensor] = [node_ebd]
@@ -815,7 +851,9 @@ def forward(
815851
# n_edge x (n_dim * 2 + e_dim)
816852
edge_info = torch.cat(
817853
[
818-
torch.index_select(node_ebd.reshape(-1, self.n_dim), 0, n2e_index),
854+
torch.index_select(
855+
node_ebd.reshape(-1, self.n_dim), 0, n2e_index
856+
),
819857
nei_node_ebd,
820858
edge_ebd,
821859
],
@@ -828,7 +866,9 @@ def forward(
828866
# nb x nloc x nnei x (h * n_dim)
829867
if not self.optim_update:
830868
assert edge_info is not None
831-
node_edge_update = self.act(self.node_edge_linear(edge_info)) * sw.unsqueeze(-1)
869+
node_edge_update = self.act(
870+
self.node_edge_linear(edge_info)
871+
) * sw.unsqueeze(-1)
832872
else:
833873
node_edge_update = self.act(
834874
self.optim_edge_update(
@@ -864,7 +904,9 @@ def forward(
864904

865905
if self.n_multi_edge_message > 1:
866906
# nb x nloc x h x n_dim
867-
node_edge_update_mul_head = node_edge_update.view(nb, nloc, self.n_multi_edge_message, self.n_dim)
907+
node_edge_update_mul_head = node_edge_update.view(
908+
nb, nloc, self.n_multi_edge_message, self.n_dim
909+
)
868910
for head_index in range(self.n_multi_edge_message):
869911
n_update_list.append(node_edge_update_mul_head[..., head_index, :])
870912
else:
@@ -920,7 +962,9 @@ def forward(
920962
# nb x nloc x a_nnei x e_dim
921963
edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :]
922964
# nb x nloc x a_nnei x e_dim
923-
edge_ebd_for_angle = torch.where(a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0)
965+
edge_ebd_for_angle = torch.where(
966+
a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0
967+
)
924968
if not self.optim_update:
925969
# nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim
926970
node_for_angle_info = (
@@ -938,18 +982,24 @@ def forward(
938982

939983
# nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim
940984
edge_for_angle_k = (
941-
torch.tile(edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1))
985+
torch.tile(
986+
edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)
987+
)
942988
if not self.use_dynamic_sel
943989
else torch.index_select(edge_ebd_for_angle, 0, eik2a_index)
944990
)
945991
# nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim
946992
edge_for_angle_j = (
947-
torch.tile(edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1))
993+
torch.tile(
994+
edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)
995+
)
948996
if not self.use_dynamic_sel
949997
else torch.index_select(edge_ebd_for_angle, 0, eij2a_index)
950998
)
951999
# nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim)
952-
edge_for_angle_info = torch.cat([edge_for_angle_k, edge_for_angle_j], dim=-1)
1000+
edge_for_angle_info = torch.cat(
1001+
[edge_for_angle_k, edge_for_angle_j], dim=-1
1002+
)
9531003
angle_info_list = [angle_ebd]
9541004
angle_info_list.append(node_for_angle_info)
9551005
angle_info_list.append(edge_for_angle_info)
@@ -987,9 +1037,15 @@ def forward(
9871037

9881038
if not self.use_dynamic_sel:
9891039
# nb x nloc x a_nnei x a_nnei x e_dim
990-
weighted_edge_angle_update = a_sw.unsqueeze(-1).unsqueeze(-1) * a_sw.unsqueeze(-2).unsqueeze(-1) * edge_angle_update
1040+
weighted_edge_angle_update = (
1041+
a_sw.unsqueeze(-1).unsqueeze(-1)
1042+
* a_sw.unsqueeze(-2).unsqueeze(-1)
1043+
* edge_angle_update
1044+
)
9911045
# nb x nloc x a_nnei x e_dim
992-
reduced_edge_angle_update = torch.sum(weighted_edge_angle_update, dim=-2) / (self.a_sel**0.5)
1046+
reduced_edge_angle_update = torch.sum(
1047+
weighted_edge_angle_update, dim=-2
1048+
) / (self.a_sel**0.5)
9931049
# nb x nloc x nnei x e_dim
9941050
padding_edge_angle_update = torch.concat(
9951051
[
@@ -1017,7 +1073,9 @@ def forward(
10171073
# will be deprecated in the future
10181074
# not support dynamic index, will pass anyway
10191075
if self.use_dynamic_sel:
1020-
raise NotImplementedError("smooth_edge_update must be True when use_dynamic_sel is True!")
1076+
raise NotImplementedError(
1077+
"smooth_edge_update must be True when use_dynamic_sel is True!"
1078+
)
10211079
full_mask = torch.concat(
10221080
[
10231081
a_nlist_mask,
@@ -1029,8 +1087,12 @@ def forward(
10291087
],
10301088
dim=-1,
10311089
)
1032-
padding_edge_angle_update = torch.where(full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd)
1033-
e_update_list.append(self.act(self.edge_angle_linear2(padding_edge_angle_update)))
1090+
padding_edge_angle_update = torch.where(
1091+
full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd
1092+
)
1093+
e_update_list.append(
1094+
self.act(self.edge_angle_linear2(padding_edge_angle_update))
1095+
)
10341096
# update edge_ebd
10351097
e_updated = self.list_update(e_update_list, "edge")
10361098

@@ -1088,7 +1150,9 @@ def list_update_res_incr(self, update_list: list[torch.Tensor]) -> torch.Tensor:
10881150
return uu
10891151

10901152
@torch.jit.export
1091-
def list_update_res_residual(self, update_list: list[torch.Tensor], update_name: str = "node") -> torch.Tensor:
1153+
def list_update_res_residual(
1154+
self, update_list: list[torch.Tensor], update_name: str = "node"
1155+
) -> torch.Tensor:
10921156
nitem = len(update_list)
10931157
uu = update_list[0]
10941158
# make jit happy
@@ -1106,7 +1170,9 @@ def list_update_res_residual(self, update_list: list[torch.Tensor], update_name:
11061170
return uu
11071171

11081172
@torch.jit.export
1109-
def list_update(self, update_list: list[torch.Tensor], update_name: str = "node") -> torch.Tensor:
1173+
def list_update(
1174+
self, update_list: list[torch.Tensor], update_name: str = "node"
1175+
) -> torch.Tensor:
11101176
if self.update_style == "res_avg":
11111177
return self.list_update_res_avg(update_list)
11121178
elif self.update_style == "res_incr":

deepmd/pt/model/network/utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from typing import (
3-
Optional,
4-
)
52

63
import torch
74

0 commit comments

Comments
 (0)