@@ -466,7 +466,8 @@ def optim_edge_update(
466466
467467 def forward (
468468 self ,
469- node_ebd_ext : torch .Tensor , # nf x nall x n_dim
469+ node_ebd : torch .Tensor , # nf x nloc x n_dim
470+ node_ebd_ext : Optional [torch .Tensor ], # nf x nall x n_dim
470471 edge_ebd : torch .Tensor , # nf x nloc x nnei x e_dim
471472 h2 : torch .Tensor , # nf x nloc x nnei x 3
472473 angle_ebd : torch .Tensor , # nf x nloc x a_nnei x a_nnei x a_dim
@@ -511,8 +512,6 @@ def forward(
511512 Updated angle embedding.
512513 """
513514 nb , nloc , nnei , _ = edge_ebd .shape
514- nall = node_ebd_ext .shape [1 ]
515- node_ebd = node_ebd_ext [:, :nloc , :]
516515 assert (nb , nloc ) == node_ebd .shape [:2 ]
517516 assert (nb , nloc , nnei ) == h2 .shape [:3 ]
518517 del a_nlist # may be used in the future
@@ -524,8 +523,10 @@ def forward(
524523 # node self mlp
525524 node_self_mlp = self .act (self .node_self_mlp (node_ebd ))
526525 n_update_list .append (node_self_mlp )
527-
528- nei_node_ebd = _make_nei_g1 (node_ebd_ext , nlist )
526+ if node_ebd_ext is not None :
527+ nei_node_ebd = _make_nei_g1 (node_ebd_ext , nlist )
528+ else :
529+ nei_node_ebd = _make_nei_g1 (node_ebd , nlist )
529530
530531 # node sym (grrg + drrd)
531532 node_sym_list : list [torch .Tensor ] = []
0 commit comments