Skip to content

Commit 9dc8ca8

Browse files
authored
Merge pull request #349 from bigict/data
refactor: wrap multi_chain_permutation_alignment
2 parents 4f4327b + fe57788 commit 9dc8ca8

2 files changed

Lines changed: 12 additions & 11 deletions

File tree

profold2/model/alphafold2.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from einops import rearrange
1414

1515
from profold2.common import residue_constants
16-
from profold2.model import commons, functional
16+
from profold2.model import commons, folding, functional
1717
from profold2.model.evoformer import Evoformer
1818
from profold2.model.head import HeaderBuilder
1919
from profold2.utils import env, exists
@@ -159,7 +159,7 @@ def __init__(
159159
accept_frame_update=accept_frame_update,
160160
attn_dropout=attn_dropout,
161161
ff_dropout=ff_dropout
162-
)
162+
) if evoformer_depth > 0 else None
163163

164164
# msa to single activations
165165
self.to_single_repr = nn.Linear(dim_msa, dim_single)
@@ -234,14 +234,13 @@ def forward(
234234
quaternions = torch.tensor([1., 0., 0., 0.], device=device)
235235
quaternions = torch.tile(quaternions, b + (n, 1))
236236
translations = torch.zeros(b + (n, 3), device=device)
237+
t = (quaternions, translations)
237238

238239
# trunk
239-
x, m, t = self.evoformer(
240-
x, m, (quaternions, translations),
241-
mask=x_mask,
242-
msa_mask=msa_mask,
243-
shard_size=shard_size
244-
)
240+
if exists(self.evoformer):
241+
x, m, t = self.evoformer(
242+
x, m, t, mask=x_mask, msa_mask=msa_mask, shard_size=shard_size
243+
)
245244

246245
s = self.to_single_repr(m[..., 0, :, :])
247246

@@ -259,9 +258,7 @@ def forward(
259258
if 'representations' in value:
260259
representations.update(value['representations'])
261260
if 'folding' in ret.headers:
262-
batch = functional.multi_chain_permutation_alignment(
263-
ret.headers['folding'], batch
264-
)
261+
batch = folding.multi_chain_permutation_alignment(ret.headers['folding'], batch)
265262
if self.training and compute_loss:
266263
for name, module, options in self.headers:
267264
if not hasattr(module, 'loss') or name not in ret.headers:

profold2/model/folding.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,7 @@ def forward(self, representations, batch):
210210
)
211211

212212
return outputs
213+
214+
215+
def multi_chain_permutation_alignment(value, batch):
216+
return functional.multi_chain_permutation_alignment(value, batch)

0 commit comments

Comments
 (0)