1313from einops import rearrange
1414
1515from profold2 .common import residue_constants
16- from profold2 .model import commons , functional
16+ from profold2 .model import commons , folding , functional
1717from profold2 .model .evoformer import Evoformer
1818from profold2 .model .head import HeaderBuilder
1919from profold2 .utils import env , exists
@@ -159,7 +159,7 @@ def __init__(
159159 accept_frame_update = accept_frame_update ,
160160 attn_dropout = attn_dropout ,
161161 ff_dropout = ff_dropout
162- )
162+ ) if evoformer_depth > 0 else None
163163
164164 # msa to single activations
165165 self .to_single_repr = nn .Linear (dim_msa , dim_single )
@@ -234,14 +234,13 @@ def forward(
234234 quaternions = torch .tensor ([1. , 0. , 0. , 0. ], device = device )
235235 quaternions = torch .tile (quaternions , b + (n , 1 ))
236236 translations = torch .zeros (b + (n , 3 ), device = device )
237+ t = (quaternions , translations )
237238
238239 # trunk
239- x , m , t = self .evoformer (
240- x , m , (quaternions , translations ),
241- mask = x_mask ,
242- msa_mask = msa_mask ,
243- shard_size = shard_size
244- )
240+ if exists (self .evoformer ):
241+ x , m , t = self .evoformer (
242+ x , m , t , mask = x_mask , msa_mask = msa_mask , shard_size = shard_size
243+ )
245244
246245 s = self .to_single_repr (m [..., 0 , :, :])
247246
@@ -259,9 +258,7 @@ def forward(
259258 if 'representations' in value :
260259 representations .update (value ['representations' ])
261260 if 'folding' in ret .headers :
262- batch = functional .multi_chain_permutation_alignment (
263- ret .headers ['folding' ], batch
264- )
261+ batch = folding .multi_chain_permutation_alignment (ret .headers ['folding' ], batch )
265262 if self .training and compute_loss :
266263 for name , module , options in self .headers :
267264 if not hasattr (module , 'loss' ) or name not in ret .headers :
0 commit comments