Skip to content

Commit 75aab93

Browse files
committed
benchmarking updates added
1 parent ad0d50c commit 75aab93

9 files changed

Lines changed: 2212 additions & 32 deletions

AFfine/ig_pipeline.py

Lines changed: 573 additions & 0 deletions
Large diffs are not rendered by default.

AFfine/predict_utils_ig.py

Lines changed: 377 additions & 0 deletions
Large diffs are not rendered by default.

AFfine/rank_pep_plddt.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import numpy as np
2+
import json
3+
import sys
4+
import os
5+
6+
def rank_peptide_plddt(npz_path, query_chainseq, output_path=None):
    """
    Rank sampled structures by mean peptide pLDDT and save the ranking as JSON.

    Args:
        npz_path: path to *_sampling_results.npz. The archive must contain
            'all_plddt' (indexed as [sample, residue]) and may contain
            'plddt_mean' (indexed by residue).
        query_chainseq: chain sequences joined by '/', e.g. "MKTL.../AVSL...".
            The peptide is taken to be the last chain.
        output_path: where to save the JSON. Defaults to the npz path with
            '_sampling_results.npz' replaced by '_pep_plddt_ranked.json';
            for any other file name the extension is replaced instead, so
            the input archive is never overwritten.

    Returns:
        dict with keys:
            'non_sampled_peptide_plddt': mean of 'plddt_mean' over the peptide
                residues rounded to 3 decimals, or None when the archive has
                no 'plddt_mean' entry.
            'n_samples': number of rows in 'all_plddt'.
            'sampled_ranked': {sample index as str: peptide mean pLDDT},
                ordered by pLDDT descending.
    """
    # Use the archive as a context manager so the underlying file handle is
    # released; pull out everything we need while it is still open.
    with np.load(npz_path) as data:
        all_plddt = data['all_plddt']  # [N_samples, N_res_padded]
        # Test membership on `.files` rather than calling NpzFile.get, which
        # is not available on older NumPy releases, and so that a missing
        # key is handled explicitly instead of failing later on None.
        plddt_mean = data['plddt_mean'] if 'plddt_mean' in data.files else None

    # The peptide occupies the residue range of the last chain.
    chain_lens = [len(chain) for chain in query_chainseq.split('/')]
    pep_start = sum(chain_lens[:-1])
    pep_end = sum(chain_lens)

    # Baseline for the non-sampled prediction. 'plddt_mean' is used as a
    # proxy here; NOTE(review): the original author's notes disagree on
    # whether it comes from a single prediction or is an average across
    # samples -- confirm against the code that writes the archive.
    if plddt_mean is None:
        non_sampled_pep = None
    else:
        non_sampled_pep = round(float(np.mean(plddt_mean[pep_start:pep_end])), 3)

    # Per-sample peptide mean pLDDT, keyed by sample index.
    sampled = {
        i: round(float(np.mean(all_plddt[i, pep_start:pep_end])), 3)
        for i in range(all_plddt.shape[0])
    }

    # Sort by pLDDT descending (stable, so ties keep sample-index order).
    ranked = sorted(sampled.items(), key=lambda item: item[1], reverse=True)

    result = {
        "non_sampled_peptide_plddt": non_sampled_pep,
        "n_samples": len(ranked),
        "sampled_ranked": {str(idx): score for idx, score in ranked},
    }

    if output_path is None:
        npz_str = os.fspath(npz_path)
        in_suffix = '_sampling_results.npz'
        out_suffix = '_pep_plddt_ranked.json'
        if npz_str.endswith(in_suffix):
            output_path = npz_str[:-len(in_suffix)] + out_suffix
        else:
            # A plain str.replace would leave the path unchanged here and the
            # JSON would overwrite the input archive.
            output_path = os.path.splitext(npz_str)[0] + out_suffix

    with open(output_path, 'w') as f:
        json.dump(result, f, indent=2)
    print(f"Saved: {output_path}")
    return result
50+
51+
52+
if __name__ == '__main__':
    # CLI: <script> <npz_path> <query_chainseq> [output_json]
    # Fail with a usage message instead of a bare IndexError when the two
    # required positional arguments are missing.
    if len(sys.argv) < 3:
        sys.exit(
            f"Usage: {sys.argv[0]} <sampling_results.npz> <query_chainseq> [output.json]"
        )
    npz_path = sys.argv[1]
    chainseq = sys.argv[2]
    # The optional third argument overrides the default JSON output location.
    out = sys.argv[3] if len(sys.argv) > 3 else None
    rank_peptide_plddt(npz_path, chainseq, out)

AFfine/run_prediction.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717
import warnings
1818
warnings.filterwarnings('ignore', category=FutureWarning)
1919
warnings.filterwarnings('ignore', category=DeprecationWarning)
20+
from rank_pep_plddt import rank_peptide_plddt
2021
###
2122
import argparse
23+
from run_prediction_ig_patch import (
24+
add_ig_pipeline_args, setup_ig_pipeline, process_target_with_ig_pipeline,
25+
)
2226

2327
parser = argparse.ArgumentParser(
2428
description="Run simple template-based alphafold inference",
@@ -71,6 +75,10 @@
7175
parser.add_argument('--no_initial_guess', action='store_true', default=False, help='When active, no intial guess is used to direct modeling and only template is used.')
7276
parser.add_argument('--return_all_outputs', action='store_true', default=False, help='Save all alphafold outputs including evoformer output')
7377
parser.add_argument('--use_msa', action='store_true', default=False, help='If Enabled, use MSA for prediction. If not, only template is used.')
78+
parser = add_ig_pipeline_args(parser)
79+
parser.add_argument('--pep_sampling', type=str, default=None,
80+
help='Peptide sampling scope: "all", "anchors", or comma-separated '
81+
'1-indexed peptide positions e.g. "2,5,9"')
7482
args = parser.parse_args()
7583

7684
import os
@@ -108,7 +116,7 @@
108116
num_recycle = args.num_recycles[0],
109117
args = args
110118
)
111-
119+
ig_config = setup_ig_pipeline(args)
112120
final_dfl = []
113121
for counter, targetl in targets.iterrows():
114122
print('START:', counter, 'of', targets.shape[0])
@@ -191,25 +199,25 @@
191199
msa = [query_sequence] + msa
192200

193201

194-
195-
196-
all_metrics = predict_utils.run_alphafold_prediction(
197-
query_sequence=query_sequence,
198-
msa=msa,
199-
deletion_matrix=deletion_matrix,
200-
chainbreak_sequence=query_chainseq,
201-
template_features=all_template_features,
202-
model_runners=model_runners,
203-
out_prefix=outfile_prefix,
204-
crop_size=crop_size,
205-
dump_pdbs = not (args.no_pdbs or args.terse),
206-
dump_metrics = not args.terse,
207-
template_pdb_dict = template_pdb_dict, # added by Amir for getting pandora data
208-
no_initial_guess=args.no_initial_guess,
209-
return_all_outputs=args.return_all_outputs
202+
# ── Build peptide mask per-target if sampling requested ──
203+
all_metrics = process_target_with_ig_pipeline(
204+
args, ig_config, targetl, query_sequence, query_chainseq,
205+
all_template_features, model_runners, outfile_prefix,
206+
crop_size, msa, deletion_matrix,
207+
)
208+
# -------- End of V2 after sampling mode -------------- #
209+
all_metrics = process_target_with_ig_pipeline(
210+
args, ig_config, targetl, query_sequence, query_chainseq,
211+
all_template_features, model_runners, outfile_prefix,
212+
crop_size, msa, deletion_matrix,
210213
)
211214

212-
215+
# ── Rank sampled structures by peptide pLDDT ──
216+
if getattr(args, 'pep_sampling', None) is not None:
217+
for model_name in args.model_names:
218+
npz_path = f'{outfile_prefix}_{model_name}_sampling_results.npz'
219+
if os.path.exists(npz_path):
220+
rank_peptide_plddt(npz_path, query_chainseq)
213221
outl = targetl.copy()
214222
for model_name, metrics in all_metrics.items():
215223
plddts = metrics['plddt']

0 commit comments

Comments
 (0)