Skip to content

Commit 1b2b345

Browse files
committed
Release v0.12.0
1 parent 7f22098 commit 1b2b345

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

46 files changed

+4997
-743
lines changed

flake.nix

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
devShells.default = pkgs.mkShell {
1919
packages = [ pkgs.bashInteractive ];
2020
shellHook = ''
21-
pixi i -e dev
22-
eval $(pixi shell-hook -e dev)
21+
uv sync --group dev
22+
source .venv/bin/activate
2323
'';
2424
};
2525
}

openprotein/embeddings/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .esm import ESMModel
1111
from .poet import PoETModel
1212
from .poet2 import PoET2Model
13+
from .ablang import AbLang2Model
1314
from .schemas import (
1415
EmbeddedSequence,
1516
EmbeddingsJob,

openprotein/embeddings/ablang.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""AbLang model."""
2+
3+
from .models import EmbeddingModel
4+
5+
6+
class AbLang2Model(EmbeddingModel):
7+
"""
8+
Community AbLang2 model that targets antibodies.
9+
10+
Examples
11+
--------
12+
View specific model details (inc supported tokens) with the `?` operator.
13+
14+
.. code-block:: python
15+
16+
>>> import openprotein
17+
>>> session = openprotein.connect(username="user", password="password")
18+
>>> session.embedding.ablang2?
19+
"""
20+
21+
model_id = ["ablang2"]

openprotein/embeddings/api.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -598,9 +598,16 @@ def request_generate_post(
598598
body["seed"] = random_seed
599599
if kwargs.get("prompt_id"):
600600
body["prompt_id"] = kwargs["prompt_id"]
601-
if kwargs.get("query_id"):
601+
if kwargs.get("design_id"):
602+
body["design_id"] = kwargs["design_id"]
603+
query_id = kwargs.get("query_id")
604+
if query_id is not None:
602605
assert model_id != "poet", f"Model with id {model_id} does not support query"
603-
body["query_id"] = kwargs["query_id"]
606+
body["query_id"] = (
607+
list(query_id)
608+
if isinstance(query_id, list)
609+
else query_id
610+
)
604611
if "use_query_structure_in_decoder" in kwargs:
605612
body["use_query_structure_in_decoder"] = kwargs[
606613
"use_query_structure_in_decoder"

openprotein/embeddings/future.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,14 @@ class EmbeddingsGenerateFuture(BaseScoreFuture[Score]):
165165
def stream(self) -> Iterator[Score]:
166166
stream = api.request_get_generate_result(session=self.session, job_id=self.id)
167167
# name, sequence, ...
168-
next(stream) # ignore header
168+
header = next(stream)
169+
has_query_id = (
170+
len(header) > 2 and header[-1].strip().lower() == "query_id"
171+
)
169172
for line in stream:
170173
# combine scores into numpy array
171-
scores = np.array([float(s) for s in line[2:]])
174+
score_values = line[2:-1] if has_query_id else line[2:]
175+
scores = np.array([float(s) for s in score_values])
172176
output = Score(name=line[0], sequence=line[1], score=scores)
173177
yield output
174178

openprotein/embeddings/poet2.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .poet import PoETModel
2121

2222
if TYPE_CHECKING:
23+
from openprotein.models.structure_generation import StructureGenerationFuture
2324
from openprotein.predictor import PredictorModel
2425
from openprotein.svd import SVDModel
2526
from openprotein.umap import UMAPModel
@@ -290,7 +291,16 @@ def single_site(
290291
def generate(
291292
self,
292293
prompt: str | Prompt | None,
293-
query: str | bytes | Protein | Complex | Query | None = None,
294+
query: (
295+
str
296+
| bytes
297+
| Protein
298+
| Complex
299+
| Query
300+
| list[str | bytes | Protein | Complex | Query]
301+
| None
302+
) = None,
303+
design: "str | StructureGenerationFuture | None" = None,
294304
use_query_structure_in_decoder: bool = True,
295305
num_samples: int = 100,
296306
temperature: float = 1.0,
@@ -308,7 +318,7 @@ def generate(
308318
----------
309319
prompt : str or Prompt or None, optional
310320
Prompt from an align workflow to condition PoET model.
311-
query : str or bytes or Protein or Complex or Query or None, optional
321+
query : str or bytes or Protein or Complex or Query or list of these or None, optional
312322
Query to use with prompt.
313323
use_query_structure_in_decoder : bool, optional
314324
Whether to use query structure in decoder. Default is True.
@@ -340,9 +350,14 @@ def generate(
340350
EmbeddingsGenerateFuture
341351
A future object representing the status and information about the generation job.
342352
"""
353+
from openprotein.models.structure_generation import StructureGenerationFuture
354+
343355
prompt_api = getattr(self.session, "prompt", None)
344356
assert isinstance(prompt_api, PromptAPI)
345-
query_id = prompt_api._resolve_query(query=query)
357+
query_id = prompt_api._resolve_query(query=query) if query is not None else None
358+
design_id = (
359+
design.job_id if isinstance(design, StructureGenerationFuture) else design
360+
)
346361
if ensemble_weights is not None:
347362
# NB: for now, ensemble_method is None -> ensemble_method == "arithmetic"
348363
if ensemble_method is None or (ensemble_method == "arithmetic"):
@@ -364,6 +379,7 @@ def generate(
364379
max_length=max_length,
365380
seed=seed,
366381
query_id=query_id,
382+
design_id=design_id,
367383
use_query_structure_in_decoder=use_query_structure_in_decoder,
368384
ensemble_weights=ensemble_weights,
369385
ensemble_method=ensemble_method,

openprotein/fold/__init__.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
11
"""
22
Fold module for predicting structures on OpenProtein.
3-
4-
isort:skip_file
53
"""
64

7-
from .schemas import FoldJob, FoldMetadata
8-
from .models import FoldModel
9-
from .esmfold import ESMFoldModel
10-
from .minifold import MiniFoldModel
115
from .alphafold2 import AlphaFold2Model
126
from .boltz import (
137
Boltz1Model,
@@ -18,6 +12,31 @@
1812
BoltzConstraint,
1913
BoltzProperty,
2014
)
21-
from .rosettafold3 import RosettaFold3Model
22-
from .future import FoldResultFuture
15+
from .esmfold import ESMFoldModel
2316
from .fold import FoldAPI
17+
from .future import FoldResultFuture
18+
from .minifold import MiniFoldModel
19+
from .models import FoldModel
20+
from .protenix import ProtenixModel
21+
from .rosettafold3 import RosettaFold3Model
22+
from .schemas import FoldJob, FoldMetadata
23+
24+
__all__ = [
25+
"FoldJob",
26+
"FoldMetadata",
27+
"FoldModel",
28+
"ESMFoldModel",
29+
"MiniFoldModel",
30+
"AlphaFold2Model",
31+
"ProtenixModel",
32+
"Boltz1Model",
33+
"Boltz1xModel",
34+
"Boltz2Model",
35+
"BoltzAffinity",
36+
"BoltzConfidence",
37+
"BoltzConstraint",
38+
"BoltzProperty",
39+
"RosettaFold3Model",
40+
"FoldResultFuture",
41+
"FoldAPI",
42+
]

openprotein/fold/alphafold2.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
"""Community-based AlphaFold 2 model running using ColabFold."""
22

3-
import io
43
import warnings
5-
from typing import Any, Sequence
4+
from typing import Sequence
65

7-
from openprotein.align import AlignAPI, MSAFuture
6+
from openprotein.align import MSAFuture
87
from openprotein.base import APISession
98
from openprotein.common import ModelMetadata
10-
from openprotein.fold.common import normalize_inputs, serialize_input
11-
from openprotein.fold.complex import id_generator
12-
from openprotein.molecules import Protein, DNA, RNA, Ligand, Complex
9+
from openprotein.fold.common import (
10+
msa_future_to_complex,
11+
normalize_inputs,
12+
serialize_input,
13+
)
14+
from openprotein.molecules import DNA, RNA, Complex, Ligand, Protein
1315

1416
from . import api
1517
from .future import FoldResultFuture
@@ -33,7 +35,7 @@ def __init__(
3335

3436
def fold(
3537
self,
36-
sequences: Sequence[Complex | Protein | str] | MSAFuture | None = None,
38+
sequences: Sequence[Complex | Protein | str | bytes] | MSAFuture,
3739
num_recycles: int | None = None,
3840
num_models: int = 1,
3941
num_relax: int = 0,
@@ -44,7 +46,7 @@ def fold(
4446
4547
Parameters
4648
----------
47-
sequences : List[Complex | Protein | str] | MSAFuture
49+
sequences : Sequence[Complex | Protein | str | bytes] | MSAFuture
4850
List of protein sequences to include in folded output. `Protein` objects must be tagged with an `msa`, which can be a `Protein.single_sequence_mode` for single sequence mode. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
4951
num_recycles : int
5052
number of times to recycle models
@@ -57,7 +59,6 @@ def fold(
5759
-------
5860
job : Job
5961
"""
60-
from openprotein.align import AlignAPI
6162

6263
if "msa" in kwargs:
6364
warnings.warn(
@@ -71,18 +72,7 @@ def fold(
7172

7273
# build the normalized_models from msa
7374
if isinstance(sequences, MSAFuture):
74-
id_gen = id_generator()
75-
align_api = getattr(self.session, "align", None)
76-
assert isinstance(align_api, AlignAPI)
77-
msa = sequences # rename
78-
seed = align_api.get_seed(job_id=msa.job.job_id)
79-
_proteins: dict[str, Protein] = {}
80-
for seq in seed.split(":"):
81-
protein = Protein(sequence=seq)
82-
id = next(id_gen)
83-
protein.msa = msa.id
84-
_proteins[id] = protein
85-
normalized_complexes = [Complex(chains=_proteins)]
75+
normalized_complexes = [msa_future_to_complex(self.session, sequences)]
8676

8777
else:
8878
normalized_complexes = normalize_inputs(sequences)

openprotein/fold/api.py

Lines changed: 10 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -199,93 +199,21 @@ def fold_get_extra_result(
199199
The result as a numpy array (for "pae", "pde", "plddt") or a list of dictionaries (for "confidence", "affinity").
200200
"""
201201
if key in {"pae", "pde", "plddt", "ptm"}:
202-
formatter = lambda response: np.load(io.BytesIO(response.content))
203-
elif key in {"confidence", "affinity"}:
204-
formatter = lambda response: response.json()
205-
elif key in {"score", "metrics"}:
206-
import pandas as pd
207-
208-
formatter = lambda response: pd.read_csv(io.StringIO(response.content.decode()))
209-
else:
210-
raise ValueError(f"Unexpected key: {key}")
211-
endpoint = PATH_PREFIX + f"/{job_id}/{sequence_or_index}/{key}"
212-
try:
213-
response = session.get(
214-
endpoint,
215-
)
216-
except HTTPError as e:
217-
if e.status_code == 400 and key == "affinity":
218-
raise ValueError("affinity not found for request") from None
219-
raise e
220-
output = formatter(response)
221-
return output
222-
223-
224-
def fold_get_complex_result(
225-
session: APISession, job_id: str, format: Literal["pdb", "mmcif"]
226-
) -> bytes:
227-
"""
228-
Get encoded result for a complex from the request ID.
229202

230-
Parameters
231-
----------
232-
session : APISession
233-
Session object for API communication.
234-
job_id : str
235-
Job ID to retrieve results from.
236-
format : {'pdb', 'mmcif'}
237-
Format of the result.
238-
239-
Returns
240-
-------
241-
bytes
242-
Encoded result for the complex.
243-
"""
244-
endpoint = PATH_PREFIX + f"/{job_id}/complex"
245-
response = session.get(
246-
endpoint,
247-
params={
248-
"format": format,
249-
},
250-
)
251-
return response.content
252-
253-
254-
def fold_get_complex_extra_result(
255-
session: APISession,
256-
job_id: str,
257-
key: Literal[
258-
"pae", "pde", "plddt", "ptm", "confidence", "affinity", "score", "metrics"
259-
],
260-
) -> "np.ndarray | list[dict] | pd.DataFrame":
261-
"""
262-
Get extra result for a complex from the request ID.
263-
264-
Parameters
265-
----------
266-
session : APISession
267-
Session object for API communication.
268-
job_id : str
269-
Job ID to retrieve results from.
270-
key : {'pae', 'pde', 'plddt', 'ptm', 'confidence', 'affinity', 'score', 'metrics'}
271-
The type of result to retrieve.
272-
273-
Returns
274-
-------
275-
numpy.ndarray or list of dict
276-
The result as a numpy array (for "pae", "pde", "plddt") or a list of dictionaries (for "confidence", "affinity").
277-
"""
278-
if key in {"pae", "pde", "plddt", "ptm"}:
279-
formatter = lambda response: np.load(io.BytesIO(response.content))
203+
def formatter(response):
204+
return np.load(io.BytesIO(response.content))
280205
elif key in {"confidence", "affinity"}:
281-
formatter = lambda response: response.json()
206+
207+
def formatter(response):
208+
return response.json()
282209
elif key in {"score", "metrics"}:
283210
import pandas as pd
284211

285-
formatter = lambda response: pd.read_csv(io.StringIO(response.content.decode()))
212+
def formatter(response):
213+
return pd.read_csv(io.StringIO(response.content.decode()))
286214
else:
287215
raise ValueError(f"Unexpected key: {key}")
288-
endpoint = PATH_PREFIX + f"/{job_id}/complex/{key}"
216+
endpoint = PATH_PREFIX + f"/{job_id}/{sequence_or_index}/{key}"
289217
try:
290218
response = session.get(
291219
endpoint,
@@ -321,28 +249,8 @@ def fold_models_post(
321249
The outer list represents the batch of requests, and the inner
322250
list represents the complex, with each item in the list being
323251
an entity in that complex. A monomer would thus be a single item.
324-
num_recycles : int, optional
325-
Number of recycles for structure prediction.
326-
num_models : int, optional
327-
Number of models to generate.
328-
num_relax : int, optional
329-
Number of relaxation steps.
330-
use_potentials : bool, optional
331-
Whether to use potentials.
332-
diffusion_samples : int, optional
333-
Number of diffusion samples (boltz).
334-
recycling_steps : int, optional
335-
Number of recycling steps (boltz).
336-
sampling_steps : int, optional
337-
Number of sampling steps (boltz).
338-
step_scale : float, optional
339-
Step scale (boltz).
340-
constraints : dict, optional
341-
Constraints to apply.
342-
templates : list, optional
343-
Templates to use.
344-
properties : dict, optional
345-
Additional properties.
252+
**kwargs
253+
Additional keyword arguments to be sent with POST body.
346254
347255
Returns
348256
-------

0 commit comments

Comments
 (0)