@@ -209,6 +209,7 @@ def __init__(
209209 use_exp_switch : bool = False ,
210210 use_dynamic_sel : bool = False ,
211211 sel_reduce_factor : float = 10.0 ,
212+ use_loc_mapping : bool = True ,
212213 optim_update : bool = True ,
213214 seed : Optional [Union [int , list [int ]]] = None ,
214215 ) -> None :
@@ -416,9 +417,9 @@ def forward(
416417 mapping : Optional [torch .Tensor ] = None ,
417418 comm_dict : Optional [dict [str , torch .Tensor ]] = None ,
418419 ):
419- if comm_dict is None :
420+ parrallel_mode = comm_dict is not None
421+ if not parrallel_mode :
420422 assert mapping is not None
421- assert extended_atype_embd is not None
422423 nframes , nloc , nnei = nlist .shape
423424 nall = extended_coord .view (nframes , - 1 ).shape [1 ] // 3
424425 atype = extended_atype [:, :nloc ]
@@ -470,12 +471,9 @@ def forward(
470471
471472 # get node embedding
472473 # [nframes, nloc, tebd_dim]
473- if comm_dict is None :
474- assert isinstance (extended_atype_embd , torch .Tensor ) # for jit
475- atype_embd = extended_atype_embd [:, :nloc , :]
476- assert list (atype_embd .shape ) == [nframes , nloc , self .n_dim ]
477- else :
478- atype_embd = extended_atype_embd
474+ assert extended_atype_embd is not None
475+ atype_embd = extended_atype_embd [:, :nloc , :]
476+ assert list (atype_embd .shape ) == [nframes , nloc , self .n_dim ]
479477 assert isinstance (atype_embd , torch .Tensor ) # for jit
480478 node_ebd = self .act (atype_embd )
481479 n_dim = node_ebd .shape [- 1 ]
@@ -494,10 +492,19 @@ def forward(
494492 cosine_ij = torch .matmul (normalized_diff_i , normalized_diff_j ) * (1 - 1e-6 )
495493 angle_input = cosine_ij .unsqueeze (- 1 ) / (torch .pi ** 0.5 )
496494
495+ if not parrallel_mode and self .use_loc_mapping :
496+ assert mapping is not None
497+ # convert nlist from nall to nloc index
498+ nlist = torch .gather (
499+ mapping ,
500+ 1 ,
501+ index = nlist .reshape (nframes , - 1 ),
502+ ).reshape (nlist .shape )
497503 if self .use_dynamic_sel :
498504 # get graph index
499505 edge_index , angle_index = get_graph_index (
500- nlist , nlist_mask , a_nlist_mask , nall
506+ nlist , nlist_mask , a_nlist_mask , nall ,
507+ use_loc_mapping = self .use_loc_mapping ,
501508 )
502509 # flat all the tensors
503510 # n_edge x 1
@@ -524,18 +531,23 @@ def forward(
524531 angle_ebd = self .angle_embd (angle_input )
525532
526533 # nb x nall x n_dim
527- if comm_dict is None :
534+ if not parrallel_mode :
528535 assert mapping is not None
529536 mapping = (
530537 mapping .view (nframes , nall ).unsqueeze (- 1 ).expand (- 1 , - 1 , self .n_dim )
531538 )
532539 for idx , ll in enumerate (self .layers ):
533540 # node_ebd: nb x nloc x n_dim
534- # node_ebd_ext: nb x nall x n_dim
535- if comm_dict is None :
541+ # node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parrallel_mode
542+ if not parrallel_mode :
536543 assert mapping is not None
537- node_ebd_ext = torch .gather (node_ebd , 1 , mapping )
544+ node_ebd_ext = (
545+ torch .gather (node_ebd , 1 , mapping )
546+ if not self .use_loc_mapping
547+ else node_ebd
548+ )
538549 else :
550+ assert comm_dict is not None
539551 has_spin = "has_spin" in comm_dict
540552 if not has_spin :
541553 n_padding = nall - nloc
0 commit comments