@@ -586,7 +586,7 @@ def optim_angle_update_dynamic(
586586 sub_node_update = paddle .matmul (node_ebd , sub_node )
587587 # n_angle * angle_dim
588588 sub_node_update = paddle .index_select (
589- sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]), n2a_index , 0
589+ sub_node_update .reshape (( nf * nloc , sub_node_update .shape [- 1 ]) ), n2a_index , 0
590590 )
591591
592592 # n_edge * angle_dim
@@ -666,7 +666,7 @@ def optim_edge_update_dynamic(
666666 sub_node_update = paddle .matmul (node_ebd , node )
667667 # n_edge * node/edge_dim
668668 sub_node_update = paddle .index_select (
669- sub_node_update .reshape (nf * nloc , sub_node_update .shape [- 1 ]),
669+ sub_node_update .reshape (( nf * nloc , sub_node_update .shape [- 1 ]) ),
670670 n2e_index ,
671671 0 ,
672672 )
@@ -675,7 +675,7 @@ def optim_edge_update_dynamic(
675675 sub_node_ext_update = paddle .matmul (node_ebd_ext , node_ext )
676676 # n_edge * node/edge_dim
677677 sub_node_ext_update = paddle .index_select (
678- sub_node_ext_update .reshape (nf * nall , sub_node_update .shape [- 1 ]),
678+ sub_node_ext_update .reshape (( nf * nall , sub_node_update .shape [- 1 ]) ),
679679 n_ext2e_index ,
680680 0 ,
681681 )
@@ -746,7 +746,7 @@ def forward(
746746 a_updated : nf x nloc x a_nnei x a_nnei x a_dim
747747 Updated angle embedding.
748748 """
749- nb , nloc , nnei , _ = edge_ebd .shape
749+ nb , nloc , nnei = nlist .shape
750750 nall = node_ebd_ext .shape [1 ]
751751 node_ebd = node_ebd_ext [:, :nloc , :]
752752 n_edge = int (nlist_mask .sum ().item ())
@@ -896,7 +896,7 @@ def forward(
896896 n2e_index ,
897897 average = False ,
898898 num_owner = nb * nloc ,
899- ).reshape (nb , nloc , node_edge_update .shape [- 1 ])
899+ ).reshape (( nb , nloc , node_edge_update .shape [- 1 ]) )
900900 / self .dynamic_e_sel
901901 )
902902 )
0 commit comments