Skip to content

Commit 525a481

Browse files
author
dohyun-s
committed
Fix jax compilation error in unpaired mode.
1 parent 548fdcc commit 525a481

3 files changed

Lines changed: 4 additions & 3 deletions

File tree

colabfold/alphafold/msa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from alphafold.model.features import FeatureDict
66
from alphafold.model.tf import shape_placeholders
7+
import jax.numpy as jnp
78

89
NUM_RES = shape_placeholders.NUM_RES
910
NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
@@ -51,7 +52,7 @@ def make_fixed_size_multimer(
5152
NUM_RES = "num residues placeholder"
5253
NUM_MSA_SEQ = "msa placeholder"
5354
NUM_TEMPLATES = "num templates placeholder"
54-
msa_cluster_size = feat["bert_mask"].shape[0]
55+
msa_cluster_size = jnp.array(feat["bert_mask"]).shape[0]
5556
pad_size_map = {
5657
NUM_RES: num_res,
5758
NUM_MSA_SEQ: msa_cluster_size,

colabfold/batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def main():
319319
host_url=args.host_url,
320320
)
321321

322-
if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired":
322+
if args.pair_mode == "none" or "unpaired" or "unpaired_paired":
323323
unpaired_path = Path(args.results).joinpath(str(jobname)+"_unpaired_env")
324324
unpaired_a3m_lines = run_mmseqs2(
325325
query_seqs_unique,

colabfold/predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def predict_structure(
6969
input_features = feature_dict
7070
input_features["asym_id"] = input_features["asym_id"] - input_features["asym_id"][...,0]
7171
# TODO
72+
input_features = pad_input_multimer(input_features, model_runner, model_name, pad_len, use_templates)
7273
if seq_len < pad_len:
73-
input_features = pad_input_multimer(input_features, model_runner, model_name, pad_len, use_templates)
7474
logger.info(f"Padding length to {pad_len}")
7575
else:
7676
if model_num == 0:

0 commit comments

Comments
 (0)