Skip to content

Commit 583ffcb

Browse files
committed
Release v0.11.0
1 parent 39619f8 commit 583ffcb

File tree

11 files changed

+401
-58
lines changed

11 files changed

+401
-58
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,6 @@ __marimo__/
367367

368368
.envrc
369369
/.direnv/
370+
371+
scratch.org
372+
gptel

openprotein/embeddings/poet.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import TYPE_CHECKING
44

55
from openprotein.base import APISession
6-
from openprotein.common import ModelMetadata, ReductionType
6+
from openprotein.common import ModelMetadata, Reduction, ReductionType
77
from openprotein.data import AssayDataset, AssayMetadata
88
from openprotein.prompt import Prompt
99

@@ -51,9 +51,9 @@ def __init__(
5151

5252
def embed(
5353
self,
54-
sequences: list[bytes],
55-
prompt: str | Prompt | None = None,
54+
sequences: list[bytes] | list[str],
5655
reduction: ReductionType | None = ReductionType.MEAN,
56+
prompt: str | Prompt | None = None,
5757
**kwargs,
5858
) -> EmbeddingsResultFuture:
5959
"""
@@ -74,6 +74,12 @@ def embed(
7474
-------
7575
EmbeddingsResultFuture
7676
Future object that returns the embeddings of the submitted sequences.
77+
78+
Note: The embeddings for PoET can have an extra first dimension if using ensemble
79+
prompts, where the first dimension is the number of replicates in the ensemble
80+
prompt, i.e. the shape is ``(N, L, D)`` if ``N`` > 1 else ``(L, D)``, where ``N`` is
81+
the number of replicates in the prompt, ``L`` is the length of the sequence, ``D`` is
82+
the embedding dimension.
7783
"""
7884
if prompt is None:
7985
prompt_id = None
@@ -88,7 +94,7 @@ def embed(
8894

8995
def logits(
9096
self,
91-
sequences: list[bytes],
97+
sequences: list[bytes] | list[str],
9298
prompt: str | Prompt | None = None,
9399
**kwargs,
94100
) -> EmbeddingsResultFuture:
@@ -108,6 +114,12 @@ def logits(
108114
-------
109115
EmbeddingsResultFuture
110116
Future object that returns the logits of the submitted sequences.
117+
118+
Note: The logits for PoET can have an extra first dimension if using ensemble
119+
prompts, where the first dimension is the number of replicates in the ensemble
120+
prompt, i.e. the shape is ``(N, L, D)`` if ``N`` > 1 else ``(L, D)``, where ``N`` is
121+
the number of replicates in the prompt, ``L`` is the length of the sequence, ``D`` is
122+
the size of the vocabulary.
111123
"""
112124
if prompt is None:
113125
prompt_id = None
@@ -317,11 +329,11 @@ def generate(
317329

318330
def fit_svd(
319331
self,
320-
prompt: str | Prompt | None = None,
321332
sequences: list[bytes] | list[str] | None = None,
322-
assay: AssayDataset | None = None,
333+
assay: AssayDataset | AssayMetadata | None = None,
323334
n_components: int = 1024,
324-
reduction: ReductionType | None = None,
335+
reduction: Reduction | ReductionType | None = None,
336+
prompt: str | Prompt | None = None,
325337
**kwargs,
326338
) -> "SVDModel":
327339
"""
@@ -365,11 +377,11 @@ def fit_svd(
365377

366378
def fit_umap(
367379
self,
368-
prompt: str | Prompt | None = None,
369380
sequences: list[bytes] | list[str] | None = None,
370-
assay: AssayDataset | None = None,
381+
assay: AssayDataset | AssayMetadata | None = None,
371382
n_components: int = 2,
372-
reduction: ReductionType = ReductionType.MEAN,
383+
reduction: Reduction | ReductionType = ReductionType.MEAN,
384+
prompt: str | Prompt | None = None,
373385
**kwargs,
374386
) -> "UMAPModel":
375387
"""
@@ -413,8 +425,11 @@ def fit_umap(
413425

414426
def fit_gp(
415427
self,
416-
assay: AssayMetadata | AssayDataset | str,
428+
assay: AssayDataset | AssayMetadata | str,
417429
properties: list[str],
430+
reduction: ReductionType,
431+
name: str | None = None,
432+
description: str | None = None,
418433
prompt: str | Prompt | None = None,
419434
**kwargs,
420435
) -> "PredictorModel":
@@ -444,6 +459,9 @@ def fit_gp(
444459
return super().fit_gp(
445460
assay=assay,
446461
properties=properties,
462+
reduction=reduction,
463+
name=name,
464+
description=description,
447465
prompt_id=prompt_id,
448466
**kwargs,
449467
)

openprotein/embeddings/poet2.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ def embed(
8686
-------
8787
EmbeddingsResultFuture
8888
A future object that returns the embeddings of the submitted sequences.
89+
90+
Note: The embeddings for PoET can have an extra first dimension if using ensemble
91+
prompts, where the first dimension is the number of replicates in the ensemble
92+
prompt, i.e. the shape is ``(N, L, D)`` if ``N`` > 1 else ``(L, D)``, where ``N`` is
93+
the number of replicates in the prompt, ``L`` is the length of the sequence, ``D`` is
94+
the embedding dimension.
8995
"""
9096
prompt_api = getattr(self.session, "prompt", None)
9197
assert isinstance(prompt_api, PromptAPI)
@@ -127,6 +133,12 @@ def logits(
127133
-------
128134
EmbeddingsResultFuture
129135
A future object that returns the logits of the submitted sequences.
136+
137+
Note: The logits for PoET can have an extra first dimension if using ensemble
138+
prompts, where the first dimension is the number of replicates in the ensemble
139+
prompt, i.e. the shape is ``(N, L, D)`` if ``N`` > 1 else ``(L, D)``, where ``N`` is
140+
the number of replicates in the prompt, ``L`` is the length of the sequence, ``D`` is
141+
the size of the vocabulary.
130142
"""
131143
prompt_api = getattr(self.session, "prompt", None)
132144
assert isinstance(prompt_api, PromptAPI)

openprotein/fold/boltz.py

Lines changed: 112 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Community-based Boltz models for complex structure prediction with ligands/DNA/RNA."""
22

33
import warnings
4-
from typing import Sequence
4+
from typing import Mapping, Sequence, cast
55

66
from pydantic import BaseModel, Field, TypeAdapter, model_validator
77

@@ -10,6 +10,8 @@
1010
from openprotein.common import ModelMetadata
1111
from openprotein.fold.common import normalize_inputs, serialize_input
1212
from openprotein.molecules import Complex, Ligand, Protein
13+
from openprotein.molecules.template import Template
14+
from openprotein.prompt import PromptAPI
1315

1416
from . import api
1517
from .complex import id_generator
@@ -40,7 +42,7 @@ def fold(
4042
num_steps: int = 200,
4143
step_scale: float = 1.638,
4244
use_potentials: bool = False,
43-
constraints: list[dict] | None = None,
45+
constraints: Sequence[Mapping] | None = None,
4446
**kwargs,
4547
) -> FoldResultFuture:
4648
"""
@@ -83,19 +85,9 @@ def fold(
8385

8486
# build the normalized_models from msa
8587
if isinstance(sequences, MSAFuture):
86-
id_gen = id_generator()
87-
align_api = getattr(self.session, "align", None)
88-
assert isinstance(align_api, AlignAPI)
89-
msa = sequences # rename
90-
seed = align_api.get_seed(job_id=msa.job.job_id)
91-
_proteins: dict[str, Protein] = {}
92-
for seq in seed.split(":"):
93-
protein = Protein(sequence=seq)
94-
id = next(id_gen)
95-
protein.msa = msa.id
96-
_proteins[id] = protein
97-
normalized_complexes = [Complex(chains=_proteins)]
98-
88+
normalized_complexes = [
89+
_msa_future_to_complex(session=self.session, msa=sequences)
90+
]
9991
else:
10092
normalized_complexes = normalize_inputs(sequences)
10193

@@ -139,9 +131,9 @@ def fold(
139131
num_steps: int = 200,
140132
step_scale: float = 1.638,
141133
use_potentials: bool = False,
142-
constraints: list[dict] | None = None,
143-
templates: list[dict] | None = None,
144-
properties: list[dict] | None = None,
134+
constraints: Sequence[Mapping] | None = None,
135+
templates: Sequence[Protein | Complex | Template] | None = None,
136+
properties: Sequence[Mapping] | None = None,
145137
method: str | None = None,
146138
) -> FoldResultFuture:
147139
"""
@@ -163,7 +155,7 @@ def fold(
163155
Whether or not to use potentials.
164156
constraints : Sequence[Mapping] | None = None
165157
List of constraints.
166-
templates: list[dict] | None = None
158+
templates : Sequence[Protein | Complex | Template] | None = None
167159
List of templates to use for structure prediction.
168160
properties : Sequence[Mapping] | None = None
169161
List of additional properties to predict. Should match the `BoltzProperties`
@@ -180,24 +172,98 @@ def fold(
180172
Returns
181173
-------
182174
FoldResultFuture
183-
Future for the folding result.
175+
Future for the folding result.
184176
"""
185-
177+
prompt_api = getattr(self.session, "prompt", None)
178+
assert isinstance(prompt_api, PromptAPI)
179+
180+
# validate templates
181+
# mapping chain_id (to predict) to template
182+
# needs to be consistent
183+
templates_: list[Template] = []
184+
if not isinstance(sequences, MSAFuture):
185+
first_chain_id_to_template = {}
186+
for batch_idx, seq in enumerate(sequences):
187+
# validate templates and normalize to complex
188+
if isinstance(seq, str) or isinstance(seq, bytes):
189+
seq = Protein(seq)
190+
seq._assert_valid_templates()
191+
if isinstance(seq, Protein):
192+
complex = Complex({"A": seq})
193+
else:
194+
complex = seq
195+
# resolve chain-level templates
196+
for chain_id, protein in complex.get_proteins().items():
197+
# Verify same chain_id should have same templates
198+
if batch_idx == 0:
199+
first_chain_id_to_template[chain_id] = protein.templates
200+
for template in protein.templates:
201+
templates_.append(_to_template(template, chain_id=chain_id))
202+
elif first_chain_id_to_template[chain_id] != protein.templates:
203+
raise ValueError(
204+
"Expected same chain across batches to have the same templates"
205+
)
206+
# resolve complex-level templates
207+
if batch_idx == 0:
208+
first_templates = complex.templates
209+
for template in complex.templates:
210+
templates_.append(_to_template(template))
211+
elif first_templates != complex.templates:
212+
raise ValueError(
213+
"Expected templates across complexes in batch to be the same"
214+
)
215+
# method level argument
186216
if templates is not None:
187-
raise ValueError("`templates` not yet supported!")
217+
if isinstance(sequences, MSAFuture):
218+
# need to convert to complex for template validation
219+
sequences = [
220+
_msa_future_to_complex(session=self.session, msa=sequences)
221+
]
222+
for template in templates:
223+
template = _to_template(template)
224+
# validate the template for all sequences before accepting it
225+
for seq in sequences:
226+
if isinstance(seq, str) or isinstance(seq, bytes):
227+
seq = Protein(seq)
228+
template.validate_for_target(seq)
229+
templates_.append(template)
230+
231+
# resolve list of Templates into expected dict arg
232+
template_dicts: list[dict] = []
233+
# track resolved queries to reduce network calls - use id() for identity-based caching
234+
struct_id_to_query_id = {}
235+
236+
for template in templates_:
237+
# Use id() for caching - only resolve each unique structure once
238+
struct_id = id(template.template)
239+
if struct_id not in struct_id_to_query_id:
240+
struct_id_to_query_id[struct_id] = prompt_api._resolve_query(
241+
query=template.template
242+
)
243+
244+
template_dict = {"query_id": struct_id_to_query_id[struct_id]}
245+
246+
if template.mapping is not None:
247+
if isinstance(template.mapping, str):
248+
template_dict["chain_id"] = template.mapping
249+
else:
250+
template_dict["chain_id"] = list(template.mapping.values())
251+
template_dict["template_id"] = list(template.mapping.keys())
252+
253+
template_dicts.append(template_dict)
188254

189255
# validate properties
190256
if properties is not None:
191257
props = TypeAdapter(list[BoltzProperty]).validate_python(properties)
192258
# Only allow affinity for ligands, and check binder refers to a ligand chain_id (str, not list)
193259
ligand_chain_ids = set()
194-
if isinstance(sequences, list):
260+
if not isinstance(sequences, MSAFuture):
195261
for protein in sequences:
196262
if isinstance(protein, Complex):
197263
complex = protein
198-
for id, chain in complex.get_chains().items():
264+
for chain_id, chain in complex.get_chains().items():
199265
if isinstance(chain, Ligand):
200-
ligand_chain_ids.add(id)
266+
ligand_chain_ids.add(chain_id)
201267
for prop in props:
202268
if hasattr(prop, "affinity") and prop.affinity is not None:
203269
binder_id = prop.affinity.binder
@@ -214,7 +280,7 @@ def fold(
214280
step_scale=step_scale,
215281
use_potentials=use_potentials,
216282
constraints=constraints,
217-
templates=templates,
283+
templates=template_dicts or None,
218284
properties=properties,
219285
method=method,
220286
)
@@ -235,7 +301,7 @@ def fold(
235301
num_steps: int = 200,
236302
step_scale: float = 1.638,
237303
use_potentials: bool = False,
238-
constraints: list[dict] | None = None,
304+
constraints: Sequence[Mapping] | None = None,
239305
) -> FoldResultFuture:
240306
"""
241307
Request structure prediction with Boltz-1 model.
@@ -305,7 +371,7 @@ def fold(
305371
num_recycles: int = 3,
306372
num_steps: int = 200,
307373
step_scale: float = 1.638,
308-
constraints: list[dict] | None = None,
374+
constraints: Sequence[Mapping] | None = None,
309375
) -> FoldResultFuture:
310376
"""
311377
Request structure prediction with Boltz-1x model. Uses potentials with Boltz-1 model.
@@ -516,3 +582,21 @@ class BoltzAffinity(BaseModel):
516582

517583
class Config:
518584
extra = "allow" # Allow extra fields
585+
586+
587+
def _msa_future_to_complex(session: APISession, msa: MSAFuture) -> Complex:
588+
align_api = getattr(session, "align", None)
589+
assert isinstance(align_api, AlignAPI)
590+
seed = align_api.get_seed(job_id=msa.job.job_id)
591+
proteins: dict[str, Protein] = {}
592+
for chain_id, seq in zip(id_generator(), seed.split(":")):
593+
protein = Protein(sequence=seq)
594+
protein.msa = msa.id
595+
proteins[chain_id] = protein
596+
return Complex(chains=proteins)
597+
598+
599+
def _to_template(obj, chain_id: str | None = None):
600+
if not isinstance(obj, Template):
601+
obj = Template(template=obj, mapping=chain_id)
602+
return obj

0 commit comments

Comments
 (0)