Skip to content

Commit fe57788

Browse files
committed
feat: disable evoformer from args
1 parent dda7ae4 commit fe57788

1 file changed

Lines changed: 6 additions & 7 deletions

File tree

profold2/model/alphafold2.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def __init__(
159159
accept_frame_update=accept_frame_update,
160160
attn_dropout=attn_dropout,
161161
ff_dropout=ff_dropout
162-
)
162+
) if evoformer_depth > 0 else None
163163

164164
# msa to single activations
165165
self.to_single_repr = nn.Linear(dim_msa, dim_single)
@@ -234,14 +234,13 @@ def forward(
234234
quaternions = torch.tensor([1., 0., 0., 0.], device=device)
235235
quaternions = torch.tile(quaternions, b + (n, 1))
236236
translations = torch.zeros(b + (n, 3), device=device)
237+
t = (quaternions, translations)
237238

238239
# trunk
239-
x, m, t = self.evoformer(
240-
x, m, (quaternions, translations),
241-
mask=x_mask,
242-
msa_mask=msa_mask,
243-
shard_size=shard_size
244-
)
240+
if exists(self.evoformer):
241+
x, m, t = self.evoformer(
242+
x, m, t, mask=x_mask, msa_mask=msa_mask, shard_size=shard_size
243+
)
245244

246245
s = self.to_single_repr(m[..., 0, :, :])
247246

0 commit comments

Comments
 (0)