Skip to content

Commit dda7ae4

Browse files
committed
refactor: wrap multi_chain_permutation_alignment
1 parent 4f4327b commit dda7ae4

2 files changed

Lines changed: 6 additions & 4 deletions

File tree

profold2/model/alphafold2.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from einops import rearrange
1414

1515
from profold2.common import residue_constants
16-
from profold2.model import commons, functional
16+
from profold2.model import commons, folding, functional
1717
from profold2.model.evoformer import Evoformer
1818
from profold2.model.head import HeaderBuilder
1919
from profold2.utils import env, exists
@@ -259,9 +259,7 @@ def forward(
259259
if 'representations' in value:
260260
representations.update(value['representations'])
261261
if 'folding' in ret.headers:
262-
batch = functional.multi_chain_permutation_alignment(
263-
ret.headers['folding'], batch
264-
)
262+
batch = folding.multi_chain_permutation_alignment(ret.headers['folding'], batch)
265263
if self.training and compute_loss:
266264
for name, module, options in self.headers:
267265
if not hasattr(module, 'loss') or name not in ret.headers:

profold2/model/folding.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,7 @@ def forward(self, representations, batch):
210210
)
211211

212212
return outputs
213+
214+
215+
def multi_chain_permutation_alignment(value, batch):
216+
return functional.multi_chain_permutation_alignment(value, batch)

0 commit comments

Comments
 (0)