@@ -210,8 +210,12 @@ def __init__(
210210 self .n_a_compress_dim = n_dim
211211 else :
212212 # angle + a_dim/c + a_dim/2c * 2 * e_rate
213- self .angle_dim += (1 + self .a_compress_e_rate ) * (self .a_dim // self .a_compress_rate )
214- self .e_a_compress_dim = self .a_dim // (2 * self .a_compress_rate ) * self .a_compress_e_rate
213+ self .angle_dim += (1 + self .a_compress_e_rate ) * (
214+ self .a_dim // self .a_compress_rate
215+ )
216+ self .e_a_compress_dim = (
217+ self .a_dim // (2 * self .a_compress_rate ) * self .a_compress_e_rate
218+ )
215219 self .n_a_compress_dim = self .a_dim // self .a_compress_rate
216220 if not self .a_compress_use_split :
217221 self .a_compress_n_linear = MLPLayer (
@@ -327,7 +331,10 @@ def _cal_hg(
327331 # nb x nloc x nnei x e_dim
328332 edge_ebd = _apply_nlist_mask (edge_ebd , nlist_mask )
329333 edge_ebd = _apply_switch (edge_ebd , sw )
330- invnnei = torch .rsqrt (float (nnei ) * torch .ones ((nb , nloc , 1 , 1 ), dtype = edge_ebd .dtype , device = edge_ebd .device ))
334+ invnnei = torch .rsqrt (
335+ float (nnei )
336+ * torch .ones ((nb , nloc , 1 , 1 ), dtype = edge_ebd .dtype , device = edge_ebd .device )
337+ )
331338 # nb x nloc x 3 x e_dim
332339 h2g2 = torch .matmul (torch .transpose (h2 , - 1 , - 2 ), edge_ebd ) * invnnei
333340 return h2g2
@@ -375,9 +382,16 @@ def _cal_hg_dynamic(
375382 # n_edge x e_dim
376383 flat_edge_ebd = flat_edge_ebd * flat_sw .unsqueeze (- 1 )
377384 # n_edge x 3 x e_dim
378- flat_h2g2 = (flat_h2 .unsqueeze (- 1 ) * flat_edge_ebd .unsqueeze (- 2 )).reshape (- 1 , 3 * e_dim )
385+ flat_h2g2 = (flat_h2 .unsqueeze (- 1 ) * flat_edge_ebd .unsqueeze (- 2 )).reshape (
386+ - 1 , 3 * e_dim
387+ )
379388 # nf x nloc x 3 x e_dim
380- h2g2 = aggregate (flat_h2g2 , owner , average = False , num_owner = num_owner ).reshape (nb , nloc , 3 , e_dim ) * scale_factor
389+ h2g2 = (
390+ aggregate (flat_h2g2 , owner , average = False , num_owner = num_owner ).reshape (
391+ nb , nloc , 3 , e_dim
392+ )
393+ * scale_factor
394+ )
381395 return h2g2
382396
383397 @staticmethod
@@ -530,7 +544,9 @@ def optim_angle_update(
530544 node_dim = node_ebd .shape [- 1 ]
531545 edge_dim = edge_ebd .shape [- 1 ]
532546 # angle_dim, node_dim, edge_dim, edge_dim
533- sub_angle , sub_node , sub_edge_ik , sub_edge_ij = torch .split (matrix , [angle_dim , node_dim , edge_dim , edge_dim ])
547+ sub_angle , sub_node , sub_edge_ik , sub_edge_ij = torch .split (
548+ matrix , [angle_dim , node_dim , edge_dim , edge_dim ]
549+ )
534550
535551 # nf * nloc * a_sel * a_sel * angle_dim
536552 sub_angle_update = torch .matmul (angle_ebd , sub_angle )
@@ -541,7 +557,11 @@ def optim_angle_update(
541557 sub_edge_update_ij = torch .matmul (edge_ebd , sub_edge_ij )
542558
543559 result_update = (
544- bias + sub_node_update .unsqueeze (2 ).unsqueeze (3 ) + sub_edge_update_ik .unsqueeze (2 ) + sub_edge_update_ij .unsqueeze (3 ) + sub_angle_update
560+ bias
561+ + sub_node_update .unsqueeze (2 ).unsqueeze (3 )
562+ + sub_edge_update_ik .unsqueeze (2 )
563+ + sub_edge_update_ij .unsqueeze (3 )
564+ + sub_angle_update
545565 )
546566 return result_update
547567
@@ -565,15 +585,19 @@ def optim_angle_update_dynamic(
565585 edge_dim = flat_edge_ebd .shape [- 1 ]
566586 angle_dim = flat_angle_ebd .shape [- 1 ]
567587 # angle_dim, node_dim, edge_dim, edge_dim
568- sub_angle , sub_node , sub_edge_ik , sub_edge_ij = torch .split (matrix , [angle_dim , node_dim , edge_dim , edge_dim ])
588+ sub_angle , sub_node , sub_edge_ik , sub_edge_ij = torch .split (
589+ matrix , [angle_dim , node_dim , edge_dim , edge_dim ]
590+ )
569591
570592 # n_angle * angle_dim
571593 sub_angle_update = torch .matmul (flat_angle_ebd , sub_angle )
572594
573595 # nf * nloc * angle_dim
574596 sub_node_update = torch .matmul (node_ebd , sub_node )
575597 # n_angle * angle_dim
576- sub_node_update = torch .index_select (sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]), 0 , n2a_index )
598+ sub_node_update = torch .index_select (
599+ sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]), 0 , n2a_index
600+ )
577601
578602 # n_edge * angle_dim
579603 sub_edge_update_ik = torch .matmul (flat_edge_ebd , sub_edge_ik )
@@ -582,7 +606,13 @@ def optim_angle_update_dynamic(
582606 sub_edge_update_ik = torch .index_select (sub_edge_update_ik , 0 , eik2a_index )
583607 sub_edge_update_ij = torch .index_select (sub_edge_update_ij , 0 , eij2a_index )
584608
585- result_update = bias + sub_node_update + sub_edge_update_ik + sub_edge_update_ij + sub_angle_update
609+ result_update = (
610+ bias
611+ + sub_node_update
612+ + sub_edge_update_ik
613+ + sub_edge_update_ij
614+ + sub_angle_update
615+ )
586616 return result_update
587617
588618 def optim_edge_update (
@@ -615,7 +645,9 @@ def optim_edge_update(
615645 # nf * nloc * nnei * node/edge_dim
616646 sub_edge_update = torch .matmul (edge_ebd , edge )
617647
618- result_update = bias + sub_node_update .unsqueeze (2 ) + sub_edge_update + sub_node_ext_update
648+ result_update = (
649+ bias + sub_node_update .unsqueeze (2 ) + sub_edge_update + sub_node_ext_update
650+ )
619651 return result_update
620652
621653 def optim_edge_update_dynamic (
@@ -643,7 +675,9 @@ def optim_edge_update_dynamic(
643675 # nf * nloc * node/edge_dim
644676 sub_node_update = torch .matmul (node_ebd , node )
645677 # n_edge * node/edge_dim
646- sub_node_update = torch .index_select (sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]), 0 , n2e_index )
678+ sub_node_update = torch .index_select (
679+ sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]), 0 , n2e_index
680+ )
647681
648682 # nf * nall * node/edge_dim
649683 sub_node_ext_update = torch .matmul (node_ebd_ext , node_ext )
@@ -742,7 +776,9 @@ def forward(
742776 nei_node_ebd = (
743777 _make_nei_g1 (node_ebd_ext , nlist )
744778 if not self .use_dynamic_sel
745- else torch .index_select (node_ebd_ext .reshape (- 1 , self .n_dim ), 0 , n_ext2e_index )
779+ else torch .index_select (
780+ node_ebd_ext .reshape (- 1 , self .n_dim ), 0 , n_ext2e_index
781+ )
746782 )
747783
748784 n_update_list : list [torch .Tensor ] = [node_ebd ]
@@ -815,7 +851,9 @@ def forward(
815851 # n_edge x (n_dim * 2 + e_dim)
816852 edge_info = torch .cat (
817853 [
818- torch .index_select (node_ebd .reshape (- 1 , self .n_dim ), 0 , n2e_index ),
854+ torch .index_select (
855+ node_ebd .reshape (- 1 , self .n_dim ), 0 , n2e_index
856+ ),
819857 nei_node_ebd ,
820858 edge_ebd ,
821859 ],
@@ -828,7 +866,9 @@ def forward(
828866 # nb x nloc x nnei x (h * n_dim)
829867 if not self .optim_update :
830868 assert edge_info is not None
831- node_edge_update = self .act (self .node_edge_linear (edge_info )) * sw .unsqueeze (- 1 )
869+ node_edge_update = self .act (
870+ self .node_edge_linear (edge_info )
871+ ) * sw .unsqueeze (- 1 )
832872 else :
833873 node_edge_update = self .act (
834874 self .optim_edge_update (
@@ -864,7 +904,9 @@ def forward(
864904
865905 if self .n_multi_edge_message > 1 :
866906 # nb x nloc x h x n_dim
867- node_edge_update_mul_head = node_edge_update .view (nb , nloc , self .n_multi_edge_message , self .n_dim )
907+ node_edge_update_mul_head = node_edge_update .view (
908+ nb , nloc , self .n_multi_edge_message , self .n_dim
909+ )
868910 for head_index in range (self .n_multi_edge_message ):
869911 n_update_list .append (node_edge_update_mul_head [..., head_index , :])
870912 else :
@@ -920,7 +962,9 @@ def forward(
920962 # nb x nloc x a_nnei x e_dim
921963 edge_ebd_for_angle = edge_ebd_for_angle [..., : self .a_sel , :]
922964 # nb x nloc x a_nnei x e_dim
923- edge_ebd_for_angle = torch .where (a_nlist_mask .unsqueeze (- 1 ), edge_ebd_for_angle , 0.0 )
965+ edge_ebd_for_angle = torch .where (
966+ a_nlist_mask .unsqueeze (- 1 ), edge_ebd_for_angle , 0.0
967+ )
924968 if not self .optim_update :
925969 # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim
926970 node_for_angle_info = (
@@ -938,18 +982,24 @@ def forward(
938982
939983 # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim
940984 edge_for_angle_k = (
941- torch .tile (edge_ebd_for_angle .unsqueeze (2 ), (1 , 1 , self .a_sel , 1 , 1 ))
985+ torch .tile (
986+ edge_ebd_for_angle .unsqueeze (2 ), (1 , 1 , self .a_sel , 1 , 1 )
987+ )
942988 if not self .use_dynamic_sel
943989 else torch .index_select (edge_ebd_for_angle , 0 , eik2a_index )
944990 )
945991 # nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim
946992 edge_for_angle_j = (
947- torch .tile (edge_ebd_for_angle .unsqueeze (3 ), (1 , 1 , 1 , self .a_sel , 1 ))
993+ torch .tile (
994+ edge_ebd_for_angle .unsqueeze (3 ), (1 , 1 , 1 , self .a_sel , 1 )
995+ )
948996 if not self .use_dynamic_sel
949997 else torch .index_select (edge_ebd_for_angle , 0 , eij2a_index )
950998 )
951999 # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim)
952- edge_for_angle_info = torch .cat ([edge_for_angle_k , edge_for_angle_j ], dim = - 1 )
1000+ edge_for_angle_info = torch .cat (
1001+ [edge_for_angle_k , edge_for_angle_j ], dim = - 1
1002+ )
9531003 angle_info_list = [angle_ebd ]
9541004 angle_info_list .append (node_for_angle_info )
9551005 angle_info_list .append (edge_for_angle_info )
@@ -987,9 +1037,15 @@ def forward(
9871037
9881038 if not self .use_dynamic_sel :
9891039 # nb x nloc x a_nnei x a_nnei x e_dim
990- weighted_edge_angle_update = a_sw .unsqueeze (- 1 ).unsqueeze (- 1 ) * a_sw .unsqueeze (- 2 ).unsqueeze (- 1 ) * edge_angle_update
1040+ weighted_edge_angle_update = (
1041+ a_sw .unsqueeze (- 1 ).unsqueeze (- 1 )
1042+ * a_sw .unsqueeze (- 2 ).unsqueeze (- 1 )
1043+ * edge_angle_update
1044+ )
9911045 # nb x nloc x a_nnei x e_dim
992- reduced_edge_angle_update = torch .sum (weighted_edge_angle_update , dim = - 2 ) / (self .a_sel ** 0.5 )
1046+ reduced_edge_angle_update = torch .sum (
1047+ weighted_edge_angle_update , dim = - 2
1048+ ) / (self .a_sel ** 0.5 )
9931049 # nb x nloc x nnei x e_dim
9941050 padding_edge_angle_update = torch .concat (
9951051 [
@@ -1017,7 +1073,9 @@ def forward(
10171073 # will be deprecated in the future
10181074 # not support dynamic index, will pass anyway
10191075 if self .use_dynamic_sel :
1020- raise NotImplementedError ("smooth_edge_update must be True when use_dynamic_sel is True!" )
1076+ raise NotImplementedError (
1077+ "smooth_edge_update must be True when use_dynamic_sel is True!"
1078+ )
10211079 full_mask = torch .concat (
10221080 [
10231081 a_nlist_mask ,
@@ -1029,8 +1087,12 @@ def forward(
10291087 ],
10301088 dim = - 1 ,
10311089 )
1032- padding_edge_angle_update = torch .where (full_mask .unsqueeze (- 1 ), padding_edge_angle_update , edge_ebd )
1033- e_update_list .append (self .act (self .edge_angle_linear2 (padding_edge_angle_update )))
1090+ padding_edge_angle_update = torch .where (
1091+ full_mask .unsqueeze (- 1 ), padding_edge_angle_update , edge_ebd
1092+ )
1093+ e_update_list .append (
1094+ self .act (self .edge_angle_linear2 (padding_edge_angle_update ))
1095+ )
10341096 # update edge_ebd
10351097 e_updated = self .list_update (e_update_list , "edge" )
10361098
@@ -1088,7 +1150,9 @@ def list_update_res_incr(self, update_list: list[torch.Tensor]) -> torch.Tensor:
10881150 return uu
10891151
10901152 @torch .jit .export
1091- def list_update_res_residual (self , update_list : list [torch .Tensor ], update_name : str = "node" ) -> torch .Tensor :
1153+ def list_update_res_residual (
1154+ self , update_list : list [torch .Tensor ], update_name : str = "node"
1155+ ) -> torch .Tensor :
10921156 nitem = len (update_list )
10931157 uu = update_list [0 ]
10941158 # make jit happy
@@ -1106,7 +1170,9 @@ def list_update_res_residual(self, update_list: list[torch.Tensor], update_name:
11061170 return uu
11071171
11081172 @torch .jit .export
1109- def list_update (self , update_list : list [torch .Tensor ], update_name : str = "node" ) -> torch .Tensor :
1173+ def list_update (
1174+ self , update_list : list [torch .Tensor ], update_name : str = "node"
1175+ ) -> torch .Tensor :
11101176 if self .update_style == "res_avg" :
11111177 return self .list_update_res_avg (update_list )
11121178 elif self .update_style == "res_incr" :
0 commit comments