@@ -47,7 +47,6 @@ def predict_structure(
4747
4848 mean_scores = []
4949 conf = []
50- unrelaxed_pdb_lines = []
5150 prediction_times = []
5251 seq_len = sum (sequences_lengths )
5352 model_names = []
@@ -103,8 +102,7 @@ def callback(result, recycles):
103102 result = result , b_factors = b_factors ,
104103 remove_leading_feature_dimension = ("multimer" not in model_type ))
105104
106- unrelaxed_pdb_lines = protein .to_pdb (unrelaxed_protein )
107- files .get ("unrelaxed" ,f"r{ recycles } .pdb" ).write_text (unrelaxed_pdb_lines )
105+ files .get ("unrelaxed" ,f"r{ recycles } .pdb" ).write_text (protein .to_pdb (unrelaxed_protein ))
108106
109107 if save_all :
110108 with files .get ("all" ,f"r{ recycles } .pickle" ).open ("wb" ) as handle :
@@ -113,83 +111,110 @@ def callback(result, recycles):
113111 ########################
114112 # predict
115113 ########################
116- start = time .time ()
117- result , recycles = \
118- model_runner .predict (input_features ,
119- random_seed = seed ,
120- return_representations = (save_all or save_single_representations or save_pair_representations ),
121- callback = callback )
122- prediction_times .append (time .time () - start )
114+ scores_path = files .get ("scores" ,"json" )
115+ if scores_path .is_file ():
116+ # if the scores file already exists, reuse the saved scores instead of re-running the prediction
117+
118+ with scores_path .open ("r" ) as handle :
119+ scores = json .load (handle )
120+
121+ mean_scores .append (float (scores ["ranking_confidence" ]))
122+
123+ print_line = ""
124+ conf .append ({})
125+ for x ,y in [["plddt" ,"pLDDT" ],["ptm" ,"pTM" ],["iptm" ,"ipTM" ]]:
126+ if x in scores :
127+ conf [- 1 ][x ] = float (np .mean (scores [x ]) if x == "plddt" else scores [x ])
128+ print_line += f" { y } ={ conf [- 1 ][x ]:.3g} "
129+ conf [- 1 ]["print_line" ] = print_line
130+ logger .info (f"{ tag } { print_line } " )
123131
124- ########################
125- # parse results
126- ########################
127- # summary metrics
128- mean_scores .append (result ["ranking_confidence" ])
129- if recycles == 0 : result .pop ("tol" ,None )
130- if not is_complex : result .pop ("iptm" ,None )
131- print_line = ""
132- conf .append ({})
133- for x ,y in [["mean_plddt" ,"pLDDT" ],["ptm" ,"pTM" ],["iptm" ,"ipTM" ]]:
134- if x in result :
135- print_line += f" { y } ={ result [x ]:.3g} "
136- conf [- 1 ][x ] = float (result [x ])
137- conf [- 1 ]["print_line" ] = print_line
138- logger .info (f"{ tag } took { prediction_times [- 1 ]:.1f} s ({ recycles } recycles)" )
139-
140- # create protein object
141- final_atom_mask = result ["structure_module" ]["final_atom_mask" ]
142- b_factors = result ["plddt" ][:, None ] * final_atom_mask
143- unrelaxed_protein = protein .from_prediction (
144- features = input_features ,
145- result = result ,
146- b_factors = b_factors ,
147- remove_leading_feature_dimension = ("multimer" not in model_type ))
132+ files .get ("unrelaxed" ,"pdb" )
133+ if save_all : files .get ("all" ,"pickle" )
134+ if save_single_representations : files .get ("single_repr" ,"npy" )
135+ if save_pair_representations : files .get ("pair_repr" ,"npy" )
148136
149- #########################
150- # save results
151- #########################
152- # save pdb
153- protein_lines = protein .to_pdb (unrelaxed_protein )
154- files .get ("unrelaxed" ,"pdb" ).write_text (protein_lines )
155- unrelaxed_pdb_lines .append (protein_lines )
156-
157- # save raw outputs
158- if save_all :
159- with files .get ("all" ,"pickle" ).open ("wb" ) as handle :
160- pickle .dump (result , handle )
161- if save_single_representations :
162- np .save (files .get ("single_repr" ,"npy" ),result ["representations" ]["single" ])
163- if save_pair_representations :
164- np .save (files .get ("pair_repr" ,"npy" ),result ["representations" ]["pair" ])
165-
166- # write an easy-to-use format (pAE and pLDDT)
167- with files .get ("scores" ,"json" ).open ("w" ) as handle :
168- plddt = result ["plddt" ][:seq_len ]
169- scores = {"plddt" : np .around (plddt .astype (float ), 2 ).tolist ()}
170- if "predicted_aligned_error" in result :
171- pae = result ["predicted_aligned_error" ][:seq_len ,:seq_len ]
172- scores .update ({"max_pae" : pae .max ().astype (float ).item (),
173- "pae" : np .around (pae .astype (float ), 2 ).tolist ()})
174- for k in ["ptm" ,"iptm" ]:
175- if k in conf [- 1 ]: scores [k ] = np .around (conf [- 1 ][k ], 2 ).item ()
176- del pae
177- del plddt
178- json .dump (scores , handle )
179-
180- ###############################
181- # callback for visualization
182- ###############################
183- if outputs_callback is not None :
184- outputs_callback ({
185- "unrelaxed_protein" :unrelaxed_protein ,
186- "sequences_lengths" :sequences_lengths ,
187- "result" :result ,
188- "input_features" :input_features ,
189- "tag" :tag ,
190- "files" :files .files [tag ]})
191-
192- del result , unrelaxed_protein
137+ else :
138+ ###########################################################
139+ start = time .time ()
140+ result , recycles = \
141+ model_runner .predict (input_features ,
142+ random_seed = seed ,
143+ return_representations = (save_all or save_single_representations or save_pair_representations ),
144+ callback = callback )
145+ prediction_times .append (time .time () - start )
146+
147+ ########################
148+ # parse results
149+ ########################
150+ # summary metrics
151+ mean_scores .append (result ["ranking_confidence" ])
152+ if recycles == 0 : result .pop ("tol" ,None )
153+ if not is_complex : result .pop ("iptm" ,None )
154+ print_line = ""
155+ conf .append ({})
156+ for x ,y in [["mean_plddt" ,"pLDDT" ],["ptm" ,"pTM" ],["iptm" ,"ipTM" ]]:
157+ if x in result :
158+ print_line += f" { y } ={ result [x ]:.3g} "
159+ conf [- 1 ][x ] = float (result [x ])
160+ conf [- 1 ]["print_line" ] = print_line
161+ logger .info (f"{ tag } took { prediction_times [- 1 ]:.1f} s ({ recycles } recycles)" )
162+
163+ # create protein object
164+ final_atom_mask = result ["structure_module" ]["final_atom_mask" ]
165+ b_factors = result ["plddt" ][:, None ] * final_atom_mask
166+ unrelaxed_protein = protein .from_prediction (
167+ features = input_features ,
168+ result = result ,
169+ b_factors = b_factors ,
170+ remove_leading_feature_dimension = ("multimer" not in model_type ))
171+
172+ #########################
173+ # save results
174+ #########################
175+ # save pdb
176+ files .get ("unrelaxed" ,"pdb" ).write_text (protein .to_pdb (unrelaxed_protein ))
177+
178+ # save raw outputs
179+ if save_all :
180+ with files .get ("all" ,"pickle" ).open ("wb" ) as handle :
181+ pickle .dump (result , handle )
182+ if save_single_representations :
183+ np .save (files .get ("single_repr" ,"npy" ),result ["representations" ]["single" ])
184+ if save_pair_representations :
185+ np .save (files .get ("pair_repr" ,"npy" ),result ["representations" ]["pair" ])
186+
187+ # write an easy-to-use format (pAE and pLDDT)
188+ with scores_path .open ("w" ) as handle :
189+ plddt = result ["plddt" ][:seq_len ]
190+ scores = {
191+ "plddt" : np .around (plddt .astype (float ), 2 ).tolist (),
192+ "ranking_confidence" : np .around (result ["ranking_confidence" ], 2 ).tolist ()
193+ }
194+ if "predicted_aligned_error" in result :
195+ pae = result ["predicted_aligned_error" ][:seq_len ,:seq_len ]
196+ scores .update ({"max_pae" : pae .max ().astype (float ).item (),
197+ "pae" : np .around (pae .astype (float ), 2 ).tolist ()})
198+ for k in ["ptm" ,"iptm" ]:
199+ if k in conf [- 1 ]: scores [k ] = np .around (conf [- 1 ][k ], 2 ).item ()
200+ del pae
201+ del plddt
202+ json .dump (scores , handle )
203+
204+ ###############################
205+ # callback for visualization
206+ ###############################
207+ if outputs_callback is not None :
208+ outputs_callback ({
209+ "unrelaxed_protein" :unrelaxed_protein ,
210+ "sequences_lengths" :sequences_lengths ,
211+ "result" :result ,
212+ "input_features" :input_features ,
213+ "tag" :tag ,
214+ "files" :files .files [tag ]})
215+
216+ del result , unrelaxed_protein
217+ ###########################################################
193218
194219 # early stop criteria fulfilled
195220 if mean_scores [- 1 ] > stop_at_score : break
@@ -215,7 +240,8 @@ def callback(result, recycles):
215240 # save relaxed pdb
216241 if n < num_relax :
217242 start = time .time ()
218- pdb_lines = run_relax (pdb_lines = unrelaxed_pdb_lines [key ], use_gpu = use_gpu_relax )
243+ pdb_filename = result_dir .joinpath (f"{ prefix } _unrelaxed_{ tag } .pdb" )
244+ pdb_lines = run_relax (pdb_filename = pdb_filename , use_gpu = use_gpu_relax )
219245 files .get ("relaxed" ,"pdb" ).write_text (pdb_lines )
220246 logger .info (f"Relaxation took { (time .time () - start ):.1f} s" )
221247
0 commit comments