11"""Community-based Boltz models for complex structure prediction with ligands/dna/rna."""
22
33import warnings
4- from typing import Sequence
4+ from typing import Mapping , Sequence , cast
55
66from pydantic import BaseModel , Field , TypeAdapter , model_validator
77
1010from openprotein .common import ModelMetadata
1111from openprotein .fold .common import normalize_inputs , serialize_input
1212from openprotein .molecules import Complex , Ligand , Protein
13+ from openprotein .molecules .template import Template
14+ from openprotein .prompt import PromptAPI
1315
1416from . import api
1517from .complex import id_generator
@@ -40,7 +42,7 @@ def fold(
4042 num_steps : int = 200 ,
4143 step_scale : float = 1.638 ,
4244 use_potentials : bool = False ,
43- constraints : list [ dict ] | None = None ,
45+ constraints : Sequence [ Mapping ] | None = None ,
4446 ** kwargs ,
4547 ) -> FoldResultFuture :
4648 """
@@ -83,19 +85,9 @@ def fold(
8385
8486 # build the normalized_models from msa
8587 if isinstance (sequences , MSAFuture ):
86- id_gen = id_generator ()
87- align_api = getattr (self .session , "align" , None )
88- assert isinstance (align_api , AlignAPI )
89- msa = sequences # rename
90- seed = align_api .get_seed (job_id = msa .job .job_id )
91- _proteins : dict [str , Protein ] = {}
92- for seq in seed .split (":" ):
93- protein = Protein (sequence = seq )
94- id = next (id_gen )
95- protein .msa = msa .id
96- _proteins [id ] = protein
97- normalized_complexes = [Complex (chains = _proteins )]
98-
88+ normalized_complexes = [
89+ _msa_future_to_complex (session = self .session , msa = sequences )
90+ ]
9991 else :
10092 normalized_complexes = normalize_inputs (sequences )
10193
@@ -139,9 +131,9 @@ def fold(
139131 num_steps : int = 200 ,
140132 step_scale : float = 1.638 ,
141133 use_potentials : bool = False ,
142- constraints : list [ dict ] | None = None ,
143- templates : list [ dict ] | None = None ,
144- properties : list [ dict ] | None = None ,
134+ constraints : Sequence [ Mapping ] | None = None ,
135+ templates : Sequence [ Protein | Complex | Template ] | None = None ,
136+ properties : Sequence [ Mapping ] | None = None ,
145137 method : str | None = None ,
146138 ) -> FoldResultFuture :
147139 """
@@ -163,7 +155,7 @@ def fold(
163155 Whether or not to use potentials.
164156 constraints : list[dict] | None = None
165157 List of constraints.
166- templates: list[dict ] | None = None
158+ templates: list[Protein | Complex | Template ] | None = None
167159 List of templates to use for structure prediction.
168160 properties: list[dict] | None = None
169161 List of additional properties to predict. Should match the `BoltzProperties`
@@ -180,24 +172,98 @@ def fold(
180172 Returns
181173 -------
182174 FoldResultFuture
183- Future for the folding result.
175+ Future for the folding result.
184176 """
185-
177+ prompt_api = getattr (self .session , "prompt" , None )
178+ assert isinstance (prompt_api , PromptAPI )
179+
180+ # validate templates
181+ # mapping chain_id (to predict) to template
182+ # needs to be consistent
183+ templates_ : list [Template ] = []
184+ if not isinstance (sequences , MSAFuture ):
185+ first_chain_id_to_template = {}
186+ for batch_idx , seq in enumerate (sequences ):
187+ # validate templates and normalize to complex
188+ if isinstance (seq , str ) or isinstance (seq , bytes ):
189+ seq = Protein (seq )
190+ seq ._assert_valid_templates ()
191+ if isinstance (seq , Protein ):
192+ complex = Complex ({"A" : seq })
193+ else :
194+ complex = seq
195+ # resolve chain-level templates
196+ for chain_id , protein in complex .get_proteins ().items ():
197+ # Verify same chain_id should have same templates
198+ if batch_idx == 0 :
199+ first_chain_id_to_template [chain_id ] = protein .templates
200+ for template in protein .templates :
201+ templates_ .append (_to_template (template , chain_id = chain_id ))
202+ elif first_chain_id_to_template [chain_id ] != protein .templates :
203+ raise ValueError (
204+ "Expected same chain across batches to have the same templates"
205+ )
206+ # resolve complex-level templates
207+ if batch_idx == 0 :
208+ first_templates = complex .templates
209+ for template in complex .templates :
210+ templates_ .append (_to_template (template ))
211+ elif first_templates != complex .templates :
212+ raise ValueError (
213+ "Expected templates across complexes in batch to be the same"
214+ )
215+ # method level argument
186216 if templates is not None :
187- raise ValueError ("`templates` not yet supported!" )
217+ if isinstance (sequences , MSAFuture ):
218+ # need to convert to complex for template validation
219+ sequences = [
220+ _msa_future_to_complex (session = self .session , msa = sequences )
221+ ]
222+ for template in templates :
223+ template = _to_template (template )
224+ # validate the template for all sequences before accepting it
225+ for seq in sequences :
226+ if isinstance (seq , str ) or isinstance (seq , bytes ):
227+ seq = Protein (seq )
228+ template .validate_for_target (seq )
229+ templates_ .append (template )
230+
231+ # resolve list of Templates into expected dict arg
232+ template_dicts : list [dict ] = []
233+ # track resolved queries to reduce network calls - use id() for identity-based caching
234+ struct_id_to_query_id = {}
235+
236+ for template in templates_ :
237+ # Use id() for caching - only resolve each unique structure once
238+ struct_id = id (template .template )
239+ if struct_id not in struct_id_to_query_id :
240+ struct_id_to_query_id [struct_id ] = prompt_api ._resolve_query (
241+ query = template .template
242+ )
243+
244+ template_dict = {"query_id" : struct_id_to_query_id [struct_id ]}
245+
246+ if template .mapping is not None :
247+ if isinstance (template .mapping , str ):
248+ template_dict ["chain_id" ] = template .mapping
249+ else :
250+ template_dict ["chain_id" ] = list (template .mapping .values ())
251+ template_dict ["template_id" ] = list (template .mapping .keys ())
252+
253+ template_dicts .append (template_dict )
188254
189255 # validate properties
190256 if properties is not None :
191257 props = TypeAdapter (list [BoltzProperty ]).validate_python (properties )
192258 # Only allow affinity for ligands, and check binder refers to a ligand chain_id (str, not list)
193259 ligand_chain_ids = set ()
194- if isinstance (sequences , list ):
260+ if not isinstance (sequences , MSAFuture ):
195261 for protein in sequences :
196262 if isinstance (protein , Complex ):
197263 complex = protein
198- for id , chain in complex .get_chains ().items ():
264+ for chain_id , chain in complex .get_chains ().items ():
199265 if isinstance (chain , Ligand ):
200- ligand_chain_ids .add (id )
266+ ligand_chain_ids .add (chain_id )
201267 for prop in props :
202268 if hasattr (prop , "affinity" ) and prop .affinity is not None :
203269 binder_id = prop .affinity .binder
@@ -214,7 +280,7 @@ def fold(
214280 step_scale = step_scale ,
215281 use_potentials = use_potentials ,
216282 constraints = constraints ,
217- templates = templates ,
283+ templates = template_dicts or None ,
218284 properties = properties ,
219285 method = method ,
220286 )
@@ -235,7 +301,7 @@ def fold(
235301 num_steps : int = 200 ,
236302 step_scale : float = 1.638 ,
237303 use_potentials : bool = False ,
238- constraints : list [ dict ] | None = None ,
304+ constraints : Sequence [ Mapping ] | None = None ,
239305 ) -> FoldResultFuture :
240306 """
241307 Request structure prediction with Boltz-1 model.
@@ -305,7 +371,7 @@ def fold(
305371 num_recycles : int = 3 ,
306372 num_steps : int = 200 ,
307373 step_scale : float = 1.638 ,
308- constraints : list [ dict ] | None = None ,
374+ constraints : Sequence [ Mapping ] | None = None ,
309375 ) -> FoldResultFuture :
310376 """
311377 Request structure prediction with Boltz-1x model. Uses potentials with Boltz-1 model.
@@ -516,3 +582,21 @@ class BoltzAffinity(BaseModel):
516582
517583 class Config :
518584 extra = "allow" # Allow extra fields
585+
586+
587+ def _msa_future_to_complex (session : APISession , msa : MSAFuture ) -> Complex :
588+ align_api = getattr (session , "align" , None )
589+ assert isinstance (align_api , AlignAPI )
590+ seed = align_api .get_seed (job_id = msa .job .job_id )
591+ proteins : dict [str , Protein ] = {}
592+ for chain_id , seq in zip (id_generator (), seed .split (":" )):
593+ protein = Protein (sequence = seq )
594+ protein .msa = msa .id
595+ proteins [chain_id ] = protein
596+ return Complex (chains = proteins )
597+
598+
599+ def _to_template (obj , chain_id : str | None = None ):
600+ if not isinstance (obj , Template ):
601+ obj = Template (template = obj , mapping = chain_id )
602+ return obj
0 commit comments