Skip to content

Commit 3f97d60

Browse files
committed
Release v0.12.1
1 parent 1b2b345 commit 3f97d60

File tree

13 files changed

+346
-59
lines changed

13 files changed

+346
-59
lines changed

openprotein/embeddings/future.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ def get_item(self, sequence: str | bytes) -> np.ndarray:
9696
return api.result_decode(data)
9797

9898

99-
Score = namedtuple("Score", ["name", "sequence", "score"])
100-
SingleSiteScore = namedtuple("SingleSiteScore", ["mut_code", "score"])
99+
Score = namedtuple("Score", ["name", "sequence", "score", "query_id"])
100+
Score.__new__.__defaults__ = (None,)
101+
SingleSiteScore = namedtuple("SingleSiteScore", ["mut_code", "score", "query_id"])
102+
SingleSiteScore.__new__.__defaults__ = (None,)
101103
S = TypeVar("S", bound=Union[Score, SingleSiteScore])
102104

103105

@@ -132,13 +134,18 @@ class EmbeddingsScoreFuture(BaseScoreFuture[Score]):
132134

133135
def stream(self) -> Iterator[Score]:
134136
stream = api.request_get_score_result(session=self.session, job_id=self.id)
135-
# name, sequence, ...
136-
next(stream) # ignore header
137+
header = next(stream)
138+
has_query_id = len(header) > 0 and header[0].strip().lower() == "query_id"
137139
for line in stream:
138-
# combine scores into numpy array
139-
scores = np.array([float(s) for s in line[2:]])
140-
output = Score(name=line[0], sequence=line[1], score=scores)
141-
yield output
140+
if has_query_id:
141+
query_id = line[0] if line[0] else None
142+
name, sequence = line[1], line[2]
143+
scores = np.array([float(s) for s in line[3:]])
144+
else:
145+
query_id = None
146+
name, sequence = line[0], line[1]
147+
scores = np.array([float(s) for s in line[2:]])
148+
yield Score(name=name, sequence=sequence, score=scores, query_id=query_id)
142149

143150

144151
class EmbeddingsScoreSingleSiteFuture(BaseScoreFuture[SingleSiteScore]):
@@ -148,13 +155,18 @@ class EmbeddingsScoreSingleSiteFuture(BaseScoreFuture[SingleSiteScore]):
148155

149156
def stream(self) -> Iterator[SingleSiteScore]:
150157
stream = api.request_get_score_result(session=self.session, job_id=self.id)
151-
# name, sequence, ...
152-
next(stream) # ignore header
158+
header = next(stream)
159+
has_query_id = len(header) > 0 and header[0].strip().lower() == "query_id"
153160
for line in stream:
154-
# combine scores into numpy array
155-
scores = np.array([float(s) for s in line[1:]])
156-
output = SingleSiteScore(mut_code=line[0], score=scores)
157-
yield output
161+
if has_query_id:
162+
query_id = line[0] if line[0] else None
163+
mut_code = line[1]
164+
scores = np.array([float(s) for s in line[2:]])
165+
else:
166+
query_id = None
167+
mut_code = line[0]
168+
scores = np.array([float(s) for s in line[1:]])
169+
yield SingleSiteScore(mut_code=mut_code, score=scores, query_id=query_id)
158170

159171

160172
class EmbeddingsGenerateFuture(BaseScoreFuture[Score]):
@@ -164,17 +176,18 @@ class EmbeddingsGenerateFuture(BaseScoreFuture[Score]):
164176

165177
def stream(self) -> Iterator[Score]:
166178
stream = api.request_get_generate_result(session=self.session, job_id=self.id)
167-
# name, sequence, ...
168179
header = next(stream)
169-
has_query_id = (
170-
len(header) > 2 and header[-1].strip().lower() == "query_id"
171-
)
180+
has_query_id = len(header) > 0 and header[0].strip().lower() == "query_id"
172181
for line in stream:
173-
# combine scores into numpy array
174-
score_values = line[2:-1] if has_query_id else line[2:]
175-
scores = np.array([float(s) for s in score_values])
176-
output = Score(name=line[0], sequence=line[1], score=scores)
177-
yield output
182+
if has_query_id:
183+
query_id = line[0] if line[0] else None
184+
name, sequence = line[1], line[2]
185+
scores = np.array([float(s) for s in line[3:]])
186+
else:
187+
query_id = None
188+
name, sequence = line[0], line[1]
189+
scores = np.array([float(s) for s in line[2:]])
190+
yield Score(name=name, sequence=sequence, score=scores, query_id=query_id)
178191

179192
@property
180193
def sequences(self):

openprotein/errors.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from pydantic import BaseModel
22
from requests import Response
33

4+
UPGRADE_MESSAGE = (
5+
"If this issue persists, try upgrading the client: pip install --upgrade openprotein-python"
6+
)
7+
48

59
# Errors for OpenProtein
610
class InvalidParameterError(Exception):
@@ -28,7 +32,7 @@ class APIError(Exception):
2832
"""APIError"""
2933

3034
def __init__(self, message: str):
31-
self.message = message
35+
self.message = f"{message}\n{UPGRADE_MESSAGE}"
3236
super().__init__(self.message)
3337

3438

openprotein/fold/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .future import FoldResultFuture
1818
from .minifold import MiniFoldModel
1919
from .models import FoldModel
20-
from .protenix import ProtenixModel
20+
from .protenix import ProtenixConfidence, ProtenixModel
2121
from .rosettafold3 import RosettaFold3Model
2222
from .schemas import FoldJob, FoldMetadata
2323

@@ -28,6 +28,7 @@
2828
"ESMFoldModel",
2929
"MiniFoldModel",
3030
"AlphaFold2Model",
31+
"ProtenixConfidence",
3132
"ProtenixModel",
3233
"Boltz1Model",
3334
"Boltz1xModel",

openprotein/fold/future.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020

2121
if TYPE_CHECKING:
2222
from .boltz import BoltzAffinity, BoltzConfidence
23+
from .protenix import ProtenixConfidence
2324

2425
FoldResult: typing.TypeAlias = (
25-
"Structure | np.ndarray | pd.DataFrame | BoltzAffinity | list[BoltzConfidence]"
26+
"Structure | np.ndarray | pd.DataFrame | BoltzAffinity | list[BoltzConfidence] | list[ProtenixConfidence]"
2627
)
2728

2829

@@ -307,9 +308,14 @@ def get_item(
307308

308309
data = TypeAdapter(BoltzAffinity).validate_python(data)
309310
elif key == "confidence":
310-
from .boltz import BoltzConfidence
311+
if self.model_id == "protenix":
312+
from .protenix import ProtenixConfidence
311313

312-
data = TypeAdapter(list[BoltzConfidence]).validate_python(data)
314+
data = TypeAdapter(list[ProtenixConfidence]).validate_python(data)
315+
else:
316+
from .boltz import BoltzConfidence
317+
318+
data = TypeAdapter(list[BoltzConfidence]).validate_python(data)
313319
return data # ty: ignore[invalid-return-type]
314320

315321
@typing.overload
@@ -588,7 +594,9 @@ def get_metrics(self) -> list[pd.DataFrame]:
588594
self._metrics = metrics
589595
return copy.deepcopy(self._metrics)
590596

591-
def get_confidence(self) -> list[list["BoltzConfidence"]]:
597+
def get_confidence(
598+
self,
599+
) -> "list[list[BoltzConfidence]] | list[list[ProtenixConfidence]]":
592600
"""
593601
Retrieve the confidences of the structure prediction.
594602
@@ -598,8 +606,8 @@ def get_confidence(self) -> list[list["BoltzConfidence"]]:
598606
599607
Returns
600608
-------
601-
list[list[BoltzConfidence]]
602-
List of list of BoltzConfidence objects.
609+
list[list[BoltzConfidence]] | list[list[ProtenixConfidence]]
610+
List of list of confidence objects (model-specific schema).
603611
604612
Raises
605613
------

openprotein/fold/protenix.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from collections.abc import Sequence
44

5+
from pydantic import BaseModel
6+
57
from openprotein.align import MSAFuture
68
from openprotein.base import APISession
79
from openprotein.common import ModelMetadata
@@ -19,6 +21,64 @@
1921
from .models import FoldModel
2022

2123

24+
class ProtenixConfidence(BaseModel):
25+
"""
26+
Per-sample confidence scores from a Protenix structure prediction.
27+
28+
Attributes
29+
----------
30+
ranking_score : float
31+
Composite ranking metric: ``0.8 * iptm + 0.2 * ptm - 100 * has_clash``.
32+
ptm : float
33+
Predicted TM-score for the full complex.
34+
iptm : float
35+
Interface pTM aggregated over inter-chain residue pairs.
36+
plddt : float
37+
Mean per-atom pLDDT in [0, 100].
38+
gpde : float
39+
Global PDE weighted by contact probabilities.
40+
has_clash : float
41+
Binary clash flag (1.0 if atomic clashes detected, else 0.0).
42+
num_recycles : int
43+
Number of recycling iterations used.
44+
disorder : float
45+
Disorder score (currently always 0.0).
46+
chain_ptm : list[float]
47+
Per-chain pTM scores, indexed by chain.
48+
chain_iptm : list[float]
49+
Per-chain ipTM scores.
50+
chain_plddt : list[float]
51+
Per-chain mean pLDDT scores.
52+
chain_gpde : list[float]
53+
Per-chain global PDE scores.
54+
chain_pair_iptm : list[list[float]]
55+
Chain-pair ipTM matrix.
56+
chain_pair_iptm_global : list[list[float]]
57+
Chain-pair ipTM matrix with ligand-aware weighting.
58+
chain_pair_gpde : list[list[float]]
59+
Chain-pair global PDE matrix.
60+
"""
61+
62+
ranking_score: float
63+
ptm: float
64+
iptm: float
65+
plddt: float
66+
gpde: float
67+
has_clash: float
68+
num_recycles: int
69+
disorder: float
70+
chain_ptm: list[float]
71+
chain_iptm: list[float]
72+
chain_plddt: list[float]
73+
chain_gpde: list[float]
74+
chain_pair_iptm: list[list[float]]
75+
chain_pair_iptm_global: list[list[float]]
76+
chain_pair_gpde: list[list[float]]
77+
78+
class Config:
79+
extra = "allow"
80+
81+
2282
class ProtenixModel(FoldModel):
2383
"""
2484
Class providing inference endpoints for Protenix structure prediction.

openprotein/jobs/futures.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from openprotein import config
2727
from openprotein.base import APISession
28-
from openprotein.errors import TimeoutException
28+
from openprotein.errors import APIError, TimeoutException
2929
from openprotein.jobs.schemas import Job, JobStatus
3030

3131
from . import api
@@ -389,7 +389,12 @@ def wait(
389389
time.sleep(1) # buffer for BE to register job
390390
job = self._wait_job(interval=interval, timeout=timeout, verbose=verbose)
391391
self.job = job
392-
return self.get()
392+
try:
393+
return self.get()
394+
except APIError:
395+
raise
396+
except Exception as e:
397+
raise APIError(f"Failed to retrieve results: {e}") from e
393398

394399

395400
class StreamingFuture(Future[list[V]], ABC, Generic[V]):
@@ -433,15 +438,20 @@ def get(self, verbose: bool = False, **kwargs) -> list[V]:
433438
A list containing all results from the job.
434439
435440
"""
436-
generator = self.stream(**kwargs)
437-
if verbose:
438-
total = None
439-
if hasattr(self, "__len__"):
440-
total = len(self) # type: ignore - static type checker doesnt know
441-
generator = tqdm.tqdm(
442-
generator, desc="Retrieving", total=total, position=0, mininterval=1.0
443-
)
444-
return [entry for entry in generator]
441+
try:
442+
generator = self.stream(**kwargs)
443+
if verbose:
444+
total = None
445+
if hasattr(self, "__len__"):
446+
total = len(self) # type: ignore - static type checker doesnt know
447+
generator = tqdm.tqdm(
448+
generator, desc="Retrieving", total=total, position=0, mininterval=1.0
449+
)
450+
return [entry for entry in generator]
451+
except APIError:
452+
raise
453+
except Exception as e:
454+
raise APIError(f"Failed to parse results: {e}") from e
445455

446456
def wait(
447457
self,

openprotein/prompt/models.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from openprotein import config
12
from openprotein.base import APISession
23
from openprotein.jobs import Future, JobsAPI
34
from openprotein.molecules import Complex, Protein
@@ -55,11 +56,12 @@ def __init__(
5556
)
5657
self.metadata = metadata
5758
self.session = session
59+
self.job = None # default for uploaded
5860
if self.metadata.job_id is not None:
5961
jobs_api = getattr(session, "jobs", None)
6062
assert isinstance(jobs_api, JobsAPI)
6163
job = PromptJob.create(jobs_api.get_job(job_id=self.metadata.job_id))
62-
super().__init__(session, job)
64+
self.job = job
6365

6466
def __str__(self) -> str:
6567
return str(self.metadata)
@@ -77,10 +79,15 @@ def get(self) -> list[list[Protein]]:
7779
context = api.get_prompt(session=self.session, prompt_id=str(self.id))
7880
return context
7981

80-
def _wait_job(self, **kwargs):
82+
def _wait_job(
83+
self,
84+
interval: float = config.POLLING_INTERVAL,
85+
timeout: int | None = None,
86+
verbose: bool = False,
87+
):
8188
if self.job is None:
8289
return None
83-
return super()._wait_job(**kwargs)
90+
return super()._wait_job(interval, timeout, verbose)
8491

8592
@property
8693
def id(self):

openprotein/utils/chain_id.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
import re
22
import string
33

4-
valid_id_pattern = re.compile(r"^[A-Z]{1,5}$|^\d{1,5}$")
4+
valid_id_pattern = re.compile(r"^[A-Za-z0-9]{1,5}$")
55

66

77
def is_valid_id(id_str: str) -> bool:
88
"""
9-
Check if the id_str matches the valid pattern for IDs (1-5 uppercase or 1-5 digits).
9+
Check if the id_str matches the valid pattern for IDs (1-5 letters or digits, case-insensitive).
1010
"""
11-
if not id_str or len(id_str) > 5:
12-
return False
13-
return bool(valid_id_pattern.fullmatch(id_str))
11+
return bool(id_str and valid_id_pattern.fullmatch(id_str))
1412

1513

1614
def id_generator(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ markers = [
8989
"slow: Slow-running tests",
9090
"integration: Integration tests",
9191
]
92-
addopts = "-v --strict-markers --tb=short --disable-warnings"
92+
addopts = "-v --strict-markers --tb=short --disable-warnings -m 'not e2e'"
9393
timeout = 1200
9494
testpaths = ["tests"]
9595
log_cli = false

tests/e2e/test_embeddings_e2e.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,24 @@ def test_e2e_poet2_generate_with_query_fanout(session: OpenProtein):
393393
)
394394

395395

396+
@pytest.mark.e2e
397+
def test_e2e_poet2_generate_with_prompt(session: OpenProtein):
398+
"""Validate PoET2 generate with a prompt that has already reached SUCCESS."""
399+
n_sequences = 2
400+
context = ["ACDEFGHIKLMNPQRSTVWY", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"]
401+
prompt = session.prompt.create_prompt(context)
402+
assert prompt.wait_until_done(timeout=TIMEOUT)
403+
404+
future = session.embedding.poet2.generate(
405+
prompt=prompt,
406+
num_samples=n_sequences,
407+
temperature=1.0,
408+
)
409+
assert future.wait_until_done(timeout=GENERATE_TIMEOUT)
410+
results = future.get()
411+
_assert_generated_sequences(results=results, expected_count=n_sequences)
412+
413+
396414
@pytest.mark.e2e
397415
def test_e2e_proteinmpnn_score_not_implemented(session: OpenProtein):
398416
with pytest.raises(NotImplementedError, match="Score not yet implemented"):

0 commit comments

Comments (0)