Skip to content

Commit a339459

Browse files
authored
Merge branch 'beta' into beta
2 parents 525a481 + b0425cd commit a339459

12 files changed

Lines changed: 906 additions & 518 deletions

File tree

AlphaFold2.ipynb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@
227227
" # install dependencies\n",
228228
" # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n",
229229
" pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold@beta\"\n",
230+
" pip uninstall -yq jax jaxlib\n",
231+
" pip install -q \"jax[cuda]==0.3.25\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
230232
"\n",
231233
" # for debugging\n",
232234
" ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold\n",
@@ -270,6 +272,7 @@
270272
"num_recycles = \"auto\" #@param [\"auto\", \"0\", \"1\", \"3\", \"6\", \"12\", \"24\", \"48\"]\n",
271273
"recycle_early_stop_tolerance = \"auto\" #@param [\"auto\", \"0.0\", \"0.5\", \"1.0\"]\n",
272274
"#@markdown - if `auto` will use `num_recycles=20 tol=0.5` for `model_type=alphafold2_multimer_v3`, else `num_recyles=3 tol=0.5`\n",
275+
"set_cyclic_offset = False #@param {type:\"boolean\"}\n",
273276
"\n",
274277
"#@markdown ### Sample settings\n",
275278
"max_msa = \"auto\" #@param [\"auto\", \"512:1024\", \"256:512\", \"64:128\", \"32:64\", \"16:32\"]\n",
@@ -382,6 +385,7 @@
382385
" inputs_callback=inputs_callback,\n",
383386
" outputs_callback=outputs_callback,\n",
384387
" save_recycles=save_recycles,\n",
388+
" cyclic=set_cyclic_offset\n",
385389
")\n",
386390
"results_zip = f\"{jobname}.result.zip\"\n",
387391
"os.system(f\"zip -r {results_zip} {jobname}\")\n",

AlphaFold2_batch.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
"id": "G4yBrceuFbf3"
3636
},
3737
"source": [
38-
"#ColabFold: AlphaFold2 w/ MMseqs2 BATCH\n",
38+
"#ColabFold v1.6.0: AlphaFold2 w/ MMseqs2 BATCH\n",
3939
"\n",
4040
"<img src=\"https://raw.githubusercontent.com/sokrypton/ColabFold/main/.github/ColabFold_Marv_Logo_Small.png\" height=\"256\" align=\"right\" style=\"height:256px\">\n",
4141
"\n",
@@ -120,8 +120,8 @@
120120
" # install dependencies\n",
121121
" # We have to use \"--no-warn-conflicts\" because colab already has a lot preinstalled with requirements different to ours\n",
122122
" pip install -q --no-warn-conflicts \"colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold@beta\"\n",
123-
" # high risk high gain\n",
124-
" pip install -q \"jax[cuda11_cudnn805]>=0.3.8,<0.4\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\n",
123+
" pip uninstall -yq jax jaxlib\n",
124+
" pip install -q \"jax[cuda]==0.3.25\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
125125
" touch COLABFOLD_READY\n",
126126
"fi\n",
127127
"\n",

colabfold/alphafold/models.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
from alphafold.model.modules import AlphaFold
77
from alphafold.model.modules_multimer import AlphaFold as AlphaFoldMultimer
88

9-
109
def load_models_and_params(
1110
num_models: int,
1211
use_templates: bool,
1312
num_recycles: Optional[int] = None,
1413
recycle_early_stop_tolerance: Optional[float] = None,
1514
num_ensemble: int = 1,
15+
model_order: Optional[List[int]] = None,
1616
model_suffix: str = "_ptm",
1717
data_dir: Path = Path("."),
1818
stop_at_score: float = 100,
@@ -23,6 +23,7 @@ def load_models_and_params(
2323
use_fuse: bool = True,
2424
use_bfloat16: bool = True,
2525
use_dropout: bool = False,
26+
use_masking: bool = True,
2627
save_all: bool = False,
2728
) -> List[Tuple[str, model.RunModel, haiku.Params]]:
2829
"""We use only two actual models and swap the parameters to avoid recompiling.
@@ -34,7 +35,11 @@ def load_models_and_params(
3435
# Use only two model and later swap params to avoid recompiling
3536
model_runner_and_params: [Tuple[str, model.RunModel, haiku.Params]] = []
3637

37-
model_order = [1, 2, 3, 4, 5]
38+
if model_order is None:
39+
model_order = [1, 2, 3, 4, 5]
40+
else:
41+
model_order.sort()
42+
3843
model_build_order = [3, 4, 5, 1, 2]
3944
if "multimer" in model_suffix:
4045
models_need_compilation = [3]
@@ -77,6 +82,13 @@ def load_models_and_params(
7782
model_config.model.embeddings_and_evoformer.num_extra_msa = max_extra_seq
7883
else:
7984
model_config.data.common.max_extra_msa = max_extra_seq
85+
86+
# disable masking
87+
if not use_masking:
88+
if "multimer" in model_suffix:
89+
model_config.model.embeddings_and_evoformer.masked_msa.replace_fraction = 0.0
90+
else:
91+
model_config.data.eval.masked_msa_replace_fraction = 0.0
8092

8193
# disable some outputs if not being saved
8294
if not save_all:

colabfold/batch.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import os
1+
import os,sys
22
ENV = {"TF_FORCE_UNIFIED_MEMORY":"1", "XLA_PYTHON_CLIENT_MEM_FRACTION":"4.0"}
33
for k,v in ENV.items():
44
if k not in os.environ: os.environ[k] = v
@@ -13,7 +13,7 @@
1313
from pathlib import Path
1414
import random
1515

16-
from colabfold.run_alphafold import run
16+
from colabfold.run_alphafold import run, set_model_type
1717
from colabfold.utils import (
1818
DEFAULT_API_SERVER, ACCEPT_DEFAULT_TERMS,
1919
get_commit, setup_logging
@@ -136,7 +136,7 @@ def main():
136136
choices=["auto", "plddt", "ptm", "iptm", "multimer"],
137137
)
138138
parser.add_argument("--pair-mode",
139-
help="rank models by auto, unpaired, paired, unpaired_paired",
139+
help="how to generate MSA for multimeric inputs: unpaired, paired, unpaired_paired",
140140
type=str,
141141
default="unpaired_paired",
142142
choices=["unpaired", "paired", "unpaired_paired"],
@@ -157,12 +157,16 @@ def main():
157157
action="store_true",
158158
help="saves the pair representation embeddings of all models",
159159
)
160-
parser.add_argument(
161-
"--use-dropout",
160+
parser.add_argument("--use-dropout",
162161
default=False,
163162
action="store_true",
164163
help="activate dropouts during inference to sample from uncertainity of the models",
165164
)
165+
parser.add_argument("--disable-masking",
166+
default=False,
167+
action="store_true",
168+
help='by default, 15% of the input MSA is randomly masked, set this flag to disable this',
169+
)
166170
parser.add_argument("--max-seq",
167171
help="number of sequence clusters to use",
168172
type=int,
@@ -203,6 +207,9 @@ def main():
203207
parser.add_argument("--interaction-scan", default=False, action="store_true")
204208
parser.add_argument("--disable-cluster-profile", default=False, action="store_true")
205209

210+
parser.add_argument("--cyclic", default=False, action="store_true")
211+
parser.add_argument("--save-best", default=False, action="store_true")
212+
206213
# backward compatability
207214
parser.add_argument('--training', default=False, action="store_true", help=argparse.SUPPRESS)
208215
parser.add_argument('--templates', default=False, action="store_true", help=argparse.SUPPRESS)
@@ -283,12 +290,15 @@ def main():
283290
save_single_representations=args.save_single_representations,
284291
save_pair_representations=args.save_pair_representations,
285292
use_dropout=args.use_dropout,
293+
use_masking=not args.disable_masking,
286294
max_seq=args.max_seq,
287295
max_extra_seq=args.max_extra_seq,
288296
use_cluster_profile=not args.disable_cluster_profile,
289-
use_gpu_relax = args.use_gpu_relax,
297+
use_gpu_relax=args.use_gpu_relax,
290298
save_all=args.save_all,
291299
save_recycles=args.save_recycles,
300+
cyclic=args.cyclic,
301+
save_best=args.save_best,
292302
)
293303

294304
if args.interaction_scan:

0 commit comments

Comments
 (0)