@@ -373,13 +373,12 @@ def forward(
373373 nlist : torch .Tensor ,
374374 extended_coord : torch .Tensor ,
375375 extended_atype : torch .Tensor ,
376- extended_atype_embd : Optional [ torch .Tensor ] = None ,
376+ extended_atype_embd : torch .Tensor ,
377377 mapping : Optional [torch .Tensor ] = None ,
378378 comm_dict : Optional [dict [str , torch .Tensor ]] = None ,
379379 ):
380380 if comm_dict is None :
381381 assert mapping is not None
382- assert extended_atype_embd is not None
383382 nframes , nloc , nnei = nlist .shape
384383 nall = extended_coord .view (nframes , - 1 ).shape [1 ] // 3
385384 atype = extended_atype [:, :nloc ]
@@ -402,14 +401,9 @@ def forward(
402401 # beyond the cutoff sw should be 0.0
403402 sw = sw .masked_fill (~ nlist_mask , 0.0 )
404403
404+ atype_embd = extended_atype_embd
405405 # [nframes, nloc, tebd_dim]
406- if comm_dict is None :
407- assert isinstance (extended_atype_embd , torch .Tensor ) # for jit
408- atype_embd = extended_atype_embd [:, :nloc , :]
409- assert list (atype_embd .shape ) == [nframes , nloc , self .n_dim ]
410- else :
411- atype_embd = extended_atype_embd
412- assert isinstance (atype_embd , torch .Tensor ) # for jit
406+ assert list (atype_embd .shape ) == [nframes , nloc , self .n_dim ]
413407 node_ebd = self .act (atype_embd )
414408 n_dim = node_ebd .shape [- 1 ]
415409 # nb x nloc x nnei x 1, nb x nloc x nnei x 3
0 commit comments