@@ -379,7 +379,6 @@ def forward(
379379 ):
380380 if comm_dict is None :
381381 assert mapping is not None
382- assert extended_atype_embd is not None
383382 nframes , nloc , nnei = nlist .shape
384383 nall = extended_coord .view (nframes , - 1 ).shape [1 ] // 3
385384 atype = extended_atype [:, :nloc ]
@@ -403,13 +402,9 @@ def forward(
403402 sw = sw .masked_fill (~ nlist_mask , 0.0 )
404403
405404 # [nframes, nloc, tebd_dim]
406- if comm_dict is None :
407- assert isinstance (extended_atype_embd , torch .Tensor ) # for jit
408- atype_embd = extended_atype_embd [:, :nloc , :]
409- assert list (atype_embd .shape ) == [nframes , nloc , self .n_dim ]
410- else :
411- atype_embd = extended_atype_embd
412- assert isinstance (atype_embd , torch .Tensor ) # for jit
405+ assert extended_atype_embd is not None
406+ atype_embd = extended_atype_embd [:, :nloc , :]
407+ assert list (atype_embd .shape ) == [nframes , nloc , self .n_dim ]
413408 node_ebd = self .act (atype_embd )
414409 n_dim = node_ebd .shape [- 1 ]
415410 # nb x nloc x nnei x 1, nb x nloc x nnei x 3
0 commit comments