We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent dda7ae4 commit fe57788Copy full SHA for fe57788
1 file changed
profold2/model/alphafold2.py
@@ -159,7 +159,7 @@ def __init__(
159
accept_frame_update=accept_frame_update,
160
attn_dropout=attn_dropout,
161
ff_dropout=ff_dropout
162
- )
+ ) if evoformer_depth > 0 else None
163
164
# msa to single activations
165
self.to_single_repr = nn.Linear(dim_msa, dim_single)
@@ -234,14 +234,13 @@ def forward(
234
quaternions = torch.tensor([1., 0., 0., 0.], device=device)
235
quaternions = torch.tile(quaternions, b + (n, 1))
236
translations = torch.zeros(b + (n, 3), device=device)
237
+ t = (quaternions, translations)
238
239
# trunk
- x, m, t = self.evoformer(
240
- x, m, (quaternions, translations),
241
- mask=x_mask,
242
- msa_mask=msa_mask,
243
- shard_size=shard_size
244
+ if exists(self.evoformer):
+ x, m, t = self.evoformer(
+ x, m, t, mask=x_mask, msa_mask=msa_mask, shard_size=shard_size
+ )
245
246
s = self.to_single_repr(m[..., 0, :, :])
247
0 commit comments