Skip to content

Commit 0b33d85

Browse files
committed
add logic to skip models already finished
1 parent 5362d51 commit 0b33d85

2 files changed

Lines changed: 106 additions & 80 deletions

File tree

colabfold/alphafold/relax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#############
22
# relax functions
33
#############
4-
4+
from pathlib import Path
55
from alphafold.relax import relax
66
from alphafold.common import protein
77

colabfold/predict.py

Lines changed: 105 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def predict_structure(
4747

4848
mean_scores = []
4949
conf = []
50-
unrelaxed_pdb_lines = []
5150
prediction_times = []
5251
seq_len = sum(sequences_lengths)
5352
model_names = []
@@ -103,8 +102,7 @@ def callback(result, recycles):
103102
result=result, b_factors=b_factors,
104103
remove_leading_feature_dimension=("multimer" not in model_type))
105104

106-
unrelaxed_pdb_lines = protein.to_pdb(unrelaxed_protein)
107-
files.get("unrelaxed",f"r{recycles}.pdb").write_text(unrelaxed_pdb_lines)
105+
files.get("unrelaxed",f"r{recycles}.pdb").write_text(protein.to_pdb(unrelaxed_protein))
108106

109107
if save_all:
110108
with files.get("all",f"r{recycles}.pickle").open("wb") as handle:
@@ -113,83 +111,110 @@ def callback(result, recycles):
113111
########################
114112
# predict
115113
########################
116-
start = time.time()
117-
result, recycles = \
118-
model_runner.predict(input_features,
119-
random_seed=seed,
120-
return_representations=(save_all or save_single_representations or save_pair_representations),
121-
callback=callback)
122-
prediction_times.append(time.time() - start)
114+
scores_path = files.get("scores","json")
115+
if scores_path.is_file():
116+
# if score file already exists, log scores and continue
117+
118+
with scores_path.open("r") as handle:
119+
scores = json.load(handle)
120+
121+
mean_scores.append(float(scores["ranking_confidence"]))
122+
123+
print_line = ""
124+
conf.append({})
125+
for x,y in [["plddt","pLDDT"],["ptm","pTM"],["iptm","ipTM"]]:
126+
if x in scores:
127+
conf[-1][x] = float(np.mean(scores[x]) if x == "plddt" else scores[x])
128+
print_line += f" {y}={conf[-1][x]:.3g}"
129+
conf[-1]["print_line"] = print_line
130+
logger.info(f"{tag}{print_line}")
123131

124-
########################
125-
# parse results
126-
########################
127-
# summary metrics
128-
mean_scores.append(result["ranking_confidence"])
129-
if recycles == 0: result.pop("tol",None)
130-
if not is_complex: result.pop("iptm",None)
131-
print_line = ""
132-
conf.append({})
133-
for x,y in [["mean_plddt","pLDDT"],["ptm","pTM"],["iptm","ipTM"]]:
134-
if x in result:
135-
print_line += f" {y}={result[x]:.3g}"
136-
conf[-1][x] = float(result[x])
137-
conf[-1]["print_line"] = print_line
138-
logger.info(f"{tag} took {prediction_times[-1]:.1f}s ({recycles} recycles)")
139-
140-
# create protein object
141-
final_atom_mask = result["structure_module"]["final_atom_mask"]
142-
b_factors = result["plddt"][:, None] * final_atom_mask
143-
unrelaxed_protein = protein.from_prediction(
144-
features=input_features,
145-
result=result,
146-
b_factors=b_factors,
147-
remove_leading_feature_dimension=("multimer" not in model_type))
132+
files.get("unrelaxed","pdb")
133+
if save_all: files.get("all","pickle")
134+
if save_single_representations: files.get("single_repr","npy")
135+
if save_pair_representations: files.get("pair_repr","npy")
148136

149-
#########################
150-
# save results
151-
#########################
152-
# save pdb
153-
protein_lines = protein.to_pdb(unrelaxed_protein)
154-
files.get("unrelaxed","pdb").write_text(protein_lines)
155-
unrelaxed_pdb_lines.append(protein_lines)
156-
157-
# save raw outputs
158-
if save_all:
159-
with files.get("all","pickle").open("wb") as handle:
160-
pickle.dump(result, handle)
161-
if save_single_representations:
162-
np.save(files.get("single_repr","npy"),result["representations"]["single"])
163-
if save_pair_representations:
164-
np.save(files.get("pair_repr","npy"),result["representations"]["pair"])
165-
166-
# write an easy-to-use format (pAE and pLDDT)
167-
with files.get("scores","json").open("w") as handle:
168-
plddt = result["plddt"][:seq_len]
169-
scores = {"plddt": np.around(plddt.astype(float), 2).tolist()}
170-
if "predicted_aligned_error" in result:
171-
pae = result["predicted_aligned_error"][:seq_len,:seq_len]
172-
scores.update({"max_pae": pae.max().astype(float).item(),
173-
"pae": np.around(pae.astype(float), 2).tolist()})
174-
for k in ["ptm","iptm"]:
175-
if k in conf[-1]: scores[k] = np.around(conf[-1][k], 2).item()
176-
del pae
177-
del plddt
178-
json.dump(scores, handle)
179-
180-
###############################
181-
# callback for visualization
182-
###############################
183-
if outputs_callback is not None:
184-
outputs_callback({
185-
"unrelaxed_protein":unrelaxed_protein,
186-
"sequences_lengths":sequences_lengths,
187-
"result":result,
188-
"input_features":input_features,
189-
"tag":tag,
190-
"files":files.files[tag]})
191-
192-
del result, unrelaxed_protein
137+
else:
138+
###########################################################
139+
start = time.time()
140+
result, recycles = \
141+
model_runner.predict(input_features,
142+
random_seed=seed,
143+
return_representations=(save_all or save_single_representations or save_pair_representations),
144+
callback=callback)
145+
prediction_times.append(time.time() - start)
146+
147+
########################
148+
# parse results
149+
########################
150+
# summary metrics
151+
mean_scores.append(result["ranking_confidence"])
152+
if recycles == 0: result.pop("tol",None)
153+
if not is_complex: result.pop("iptm",None)
154+
print_line = ""
155+
conf.append({})
156+
for x,y in [["mean_plddt","pLDDT"],["ptm","pTM"],["iptm","ipTM"]]:
157+
if x in result:
158+
print_line += f" {y}={result[x]:.3g}"
159+
conf[-1][x] = float(result[x])
160+
conf[-1]["print_line"] = print_line
161+
logger.info(f"{tag} took {prediction_times[-1]:.1f}s ({recycles} recycles)")
162+
163+
# create protein object
164+
final_atom_mask = result["structure_module"]["final_atom_mask"]
165+
b_factors = result["plddt"][:, None] * final_atom_mask
166+
unrelaxed_protein = protein.from_prediction(
167+
features=input_features,
168+
result=result,
169+
b_factors=b_factors,
170+
remove_leading_feature_dimension=("multimer" not in model_type))
171+
172+
#########################
173+
# save results
174+
#########################
175+
# save pdb
176+
files.get("unrelaxed","pdb").write_text(protein.to_pdb(unrelaxed_protein))
177+
178+
# save raw outputs
179+
if save_all:
180+
with files.get("all","pickle").open("wb") as handle:
181+
pickle.dump(result, handle)
182+
if save_single_representations:
183+
np.save(files.get("single_repr","npy"),result["representations"]["single"])
184+
if save_pair_representations:
185+
np.save(files.get("pair_repr","npy"),result["representations"]["pair"])
186+
187+
# write an easy-to-use format (pAE and pLDDT)
188+
with scores_path.open("w") as handle:
189+
plddt = result["plddt"][:seq_len]
190+
scores = {
191+
"plddt": np.around(plddt.astype(float), 2).tolist(),
192+
"ranking_confidence": np.around(result["ranking_confidence"], 2).tolist()
193+
}
194+
if "predicted_aligned_error" in result:
195+
pae = result["predicted_aligned_error"][:seq_len,:seq_len]
196+
scores.update({"max_pae": pae.max().astype(float).item(),
197+
"pae": np.around(pae.astype(float), 2).tolist()})
198+
for k in ["ptm","iptm"]:
199+
if k in conf[-1]: scores[k] = np.around(conf[-1][k], 2).item()
200+
del pae
201+
del plddt
202+
json.dump(scores, handle)
203+
204+
###############################
205+
# callback for visualization
206+
###############################
207+
if outputs_callback is not None:
208+
outputs_callback({
209+
"unrelaxed_protein":unrelaxed_protein,
210+
"sequences_lengths":sequences_lengths,
211+
"result":result,
212+
"input_features":input_features,
213+
"tag":tag,
214+
"files":files.files[tag]})
215+
216+
del result, unrelaxed_protein
217+
###########################################################
193218

194219
# early stop criteria fulfilled
195220
if mean_scores[-1] > stop_at_score: break
@@ -215,7 +240,8 @@ def callback(result, recycles):
215240
# save relaxed pdb
216241
if n < num_relax:
217242
start = time.time()
218-
pdb_lines = run_relax(pdb_lines=unrelaxed_pdb_lines[key], use_gpu=use_gpu_relax)
243+
pdb_filename = result_dir.joinpath(f"{prefix}_unrelaxed_{tag}.pdb")
244+
pdb_lines = run_relax(pdb_filename=pdb_filename, use_gpu=use_gpu_relax)
219245
files.get("relaxed","pdb").write_text(pdb_lines)
220246
logger.info(f"Relaxation took {(time.time() - start):.1f}s")
221247

0 commit comments

Comments
 (0)