Skip to content

Commit efb63a2

Browse files
authored
[codex] Add flank-aware protein scanning (#207)
* Add flank-aware protein scanning
* Relax MHCflurry parity tolerance
1 parent 9676262 commit efb63a2

8 files changed

Lines changed: 444 additions & 103 deletions

mhctools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ def __getattr__(name):
6363
raise AttributeError(
6464
"module %r has no attribute %r" % (__name__, name))
6565

66-
__version__ = "3.13.5"
66+
__version__ = "3.13.6"
6767

6868
__all__ = [
6969
"Prediction",

mhctools/base_commandline_predictor.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -326,16 +326,19 @@ def _run_commands_and_collect_preds(
326326
groups[key].append(pred)
327327
return [PeptideResult(preds=tuple(preds)) for preds in groups.values()]
328328

329-
def predict(self, peptides):
329+
def predict(self, peptides, n_flanks=None, c_flanks=None):
330330
"""
331331
Predict for a list of peptide sequences.
332332
333333
Returns list of PeptideResult. When a native parse_to_preds_fn is
334334
available, parses directly to Pred objects. Otherwise falls back
335335
to converting from BindingPrediction.
336336
"""
337+
peptides, n_flank_list, c_flank_list = self._check_flank_inputs(
338+
peptides, n_flanks, c_flanks)
337339
if self.parse_to_preds_fn is None:
338-
return super().predict(peptides)
340+
return super().predict(
341+
peptides, n_flanks=n_flank_list, c_flanks=c_flank_list)
339342

340343
self._check_peptide_inputs(peptides)
341344
input_filenames = create_input_peptides_files(

mhctools/base_predictor.py

Lines changed: 186 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313
import logging
1414
import warnings
15-
from collections import defaultdict
15+
from collections import defaultdict, namedtuple
1616

1717
from typechecks import require_iterable_of
1818
from .allele_normalization import normalize_allele_name
@@ -23,10 +23,85 @@
2323

2424
logger = logging.getLogger(__name__)
2525

26+
PeptideContext = namedtuple(
27+
"PeptideContext",
28+
["source_sequence_name", "offset", "peptide", "n_flank", "c_flank"])
29+
30+
31+
def _normalize_sequence_dict(sequence_dict):
32+
if isinstance(sequence_dict, str):
33+
return {"seq": sequence_dict}
34+
if isinstance(sequence_dict, (list, tuple)):
35+
return {seq: seq for seq in sequence_dict}
36+
return sequence_dict
37+
38+
39+
def _check_flank_inputs(peptides, n_flanks=None, c_flanks=None):
40+
peptide_list = list(peptides)
41+
42+
def _check_flanks(name, flanks):
43+
if flanks is None:
44+
return None
45+
if isinstance(flanks, str):
46+
raise TypeError("%s must be a sequence of strings, not a string" % name)
47+
flank_list = list(flanks)
48+
require_iterable_of(flank_list, str, name)
49+
if len(flank_list) != len(peptide_list):
50+
raise ValueError(
51+
"%s must have one entry per peptide, got %d flank(s) for "
52+
"%d peptide(s)" % (name, len(flank_list), len(peptide_list)))
53+
return flank_list
54+
55+
return (
56+
peptide_list,
57+
_check_flanks("n_flanks", n_flanks),
58+
_check_flanks("c_flanks", c_flanks),
59+
)
60+
61+
62+
def _peptide_contexts(
63+
sequence_dict,
64+
peptide_lengths,
65+
flank_length=0,
66+
n_flank_length=None,
67+
c_flank_length=None):
68+
if n_flank_length is None:
69+
n_flank_length = flank_length
70+
if c_flank_length is None:
71+
c_flank_length = flank_length
72+
73+
contexts = []
74+
for name, sequence in sequence_dict.items():
75+
for peptide_length in peptide_lengths:
76+
for i in range(len(sequence) - peptide_length + 1):
77+
peptide = sequence[i:i + peptide_length]
78+
if n_flank_length or c_flank_length:
79+
n_flank = sequence[max(0, i - n_flank_length):i]
80+
c_flank = sequence[
81+
i + peptide_length:
82+
i + peptide_length + c_flank_length]
83+
else:
84+
n_flank = ""
85+
c_flank = ""
86+
contexts.append(PeptideContext(
87+
source_sequence_name=name,
88+
offset=i,
89+
peptide=peptide,
90+
n_flank=n_flank,
91+
c_flank=c_flank,
92+
))
93+
return contexts
94+
95+
2696
class BasePredictor(object):
2797
"""
2898
Base class for all MHC binding predictors.
2999
"""
100+
uses_flanking_sequences = False
101+
flank_length = 15
102+
n_flank_length = None
103+
c_flank_length = None
104+
30105
def __init__(
31106
self,
32107
alleles,
@@ -90,26 +165,85 @@ def __str__(self):
90165

91166
# --- new API ---
92167

93-
def predict(self, peptides):
168+
def predict(self, peptides, n_flanks=None, c_flanks=None):
94169
"""
95170
Predict for a list of peptide sequences.
96171
172+
n_flanks and c_flanks are accepted for a uniform flank-aware API.
173+
The default BindingPrediction-based implementation validates but
174+
ignores them; subclasses that use flanking context should override
175+
this method and set ``uses_flanking_sequences = True``.
176+
97177
Returns
98178
-------
99179
list of PeptideResult
100180
"""
101-
collection = self.predict_peptides(peptides)
181+
peptide_list, _, _ = _check_flank_inputs(
182+
peptides, n_flanks, c_flanks)
183+
collection = self.predict_peptides(peptide_list)
102184
return collection.to_peptide_preds(kind=self._default_pred_kind())
103185

104-
def predict_dataframe(self, peptides, sample_name=""):
186+
def predict_with_flanks(self, peptides, n_flanks, c_flanks):
187+
"""
188+
Optional flank-aware prediction path.
189+
190+
The default implementation validates aligned flank lists and falls
191+
back to ``predict(peptides)`` so predictors that do not use flanking
192+
context retain their existing behavior. Subclasses whose models use
193+
flanks should override this method and forward the flanks to the
194+
underlying predictor.
195+
"""
196+
peptide_list, _, _ = _check_flank_inputs(peptides, n_flanks, c_flanks)
197+
return self.predict(peptide_list)
198+
199+
def predict_dataframe(
200+
self, peptides, sample_name="", n_flanks=None, c_flanks=None):
105201
"""predict() flattened to a DataFrame."""
106202
import pandas as pd
107-
dfs = [pp.to_dataframe(sample_name) for pp in self.predict(peptides)]
203+
dfs = [
204+
pp.to_dataframe(sample_name)
205+
for pp in self.predict(
206+
peptides, n_flanks=n_flanks, c_flanks=c_flanks)
207+
]
108208
if not dfs:
109209
from .pred import COLUMNS
110210
return pd.DataFrame(columns=COLUMNS)
111211
return pd.concat(dfs, ignore_index=True)
112212

213+
def _predict_protein_flank_length(self):
214+
return max(self._predict_protein_flank_lengths())
215+
216+
def _predict_protein_flank_lengths(self):
217+
if not self.uses_flanking_sequences:
218+
return (0, 0)
219+
n_flank_length = (
220+
self.flank_length
221+
if self.n_flank_length is None else self.n_flank_length)
222+
c_flank_length = (
223+
self.flank_length
224+
if self.c_flank_length is None else self.c_flank_length)
225+
return (n_flank_length, c_flank_length)
226+
227+
def _check_flank_inputs(self, peptides, n_flanks=None, c_flanks=None):
228+
return _check_flank_inputs(peptides, n_flanks, c_flanks)
229+
230+
@staticmethod
231+
def _with_protein_location(pred, context):
232+
return Prediction(
233+
kind=pred.kind,
234+
score=pred.score,
235+
peptide=pred.peptide,
236+
allele=pred.allele,
237+
n_flank=context.n_flank,
238+
c_flank=context.c_flank,
239+
value=pred.value,
240+
percentile_rank=pred.percentile_rank,
241+
source_sequence_name=context.source_sequence_name,
242+
offset=context.offset,
243+
predictor_name=pred.predictor_name,
244+
predictor_version=pred.predictor_version,
245+
)
246+
113247
def predict_proteins(self, sequence_dict, peptide_lengths=None):
114248
"""
115249
Scan protein sequences and predict for all subsequences.
@@ -126,23 +260,44 @@ def predict_proteins(self, sequence_dict, peptide_lengths=None):
126260
-------
127261
dict mapping sequence_name -> list of PeptideResult
128262
"""
129-
if isinstance(sequence_dict, str):
130-
sequence_dict = {"seq": sequence_dict}
131-
elif isinstance(sequence_dict, (list, tuple)):
132-
sequence_dict = {seq: seq for seq in sequence_dict}
263+
sequence_dict = _normalize_sequence_dict(sequence_dict)
133264

134265
peptide_lengths = self._check_peptide_lengths(peptide_lengths)
135266

136-
peptide_set = set()
137-
peptide_to_name_offset_pairs = defaultdict(list)
138-
139-
for name, sequence in sequence_dict.items():
140-
for peptide_length in peptide_lengths:
141-
for i in range(len(sequence) - peptide_length + 1):
142-
peptide = sequence[i:i + peptide_length]
143-
peptide_set.add(peptide)
144-
peptide_to_name_offset_pairs[peptide].append((name, i))
267+
n_flank_length, c_flank_length = self._predict_protein_flank_lengths()
268+
contexts = _peptide_contexts(
269+
sequence_dict,
270+
peptide_lengths,
271+
n_flank_length=n_flank_length,
272+
c_flank_length=c_flank_length)
273+
274+
if self.uses_flanking_sequences:
275+
peptide_list = [context.peptide for context in contexts]
276+
n_flanks = [context.n_flank for context in contexts]
277+
c_flanks = [context.c_flank for context in contexts]
278+
flat_preds = self.predict_with_flanks(
279+
peptide_list,
280+
n_flanks=n_flanks,
281+
c_flanks=c_flanks)
282+
if len(flat_preds) != len(contexts):
283+
raise ValueError(
284+
"%s.predict returned %d result(s) for %d flanked "
285+
"peptide occurrence(s)" % (
286+
self.__class__.__name__, len(flat_preds),
287+
len(contexts)))
288+
results = defaultdict(list)
289+
for context, pp in zip(contexts, flat_preds):
290+
relocated = PeptideResult(preds=tuple(
291+
self._with_protein_location(p, context)
292+
for p in pp.preds
293+
))
294+
results[context.source_sequence_name].append(relocated)
295+
return dict(results)
145296

297+
peptide_set = {context.peptide for context in contexts}
298+
peptide_to_contexts = defaultdict(list)
299+
for context in contexts:
300+
peptide_to_contexts[context.peptide].append(context)
146301
peptide_list = sorted(peptide_set)
147302
flat_preds = self.predict(peptide_list)
148303

@@ -152,22 +307,12 @@ def predict_proteins(self, sequence_dict, peptide_lengths=None):
152307
if not pp.preds:
153308
continue
154309
peptide = pp.preds[0].peptide
155-
for name, offset in peptide_to_name_offset_pairs.get(peptide, []):
310+
for context in peptide_to_contexts.get(peptide, []):
156311
relocated = PeptideResult(preds=tuple(
157-
Prediction(
158-
kind=p.kind,
159-
score=p.score,
160-
peptide=p.peptide,
161-
allele=p.allele,
162-
value=p.value,
163-
percentile_rank=p.percentile_rank,
164-
source_sequence_name=name,
165-
offset=offset,
166-
predictor_name=p.predictor_name,
167-
predictor_version=p.predictor_version,
168-
) for p in pp.preds
312+
self._with_protein_location(p, context)
313+
for p in pp.preds
169314
))
170-
results[name].append(relocated)
315+
results[context.source_sequence_name].append(relocated)
171316
return dict(results)
172317

173318
def predict_proteins_dataframe(self, sequence_dict, peptide_lengths=None, sample_name=""):
@@ -312,25 +457,15 @@ def predict_subsequences(
312457
and an optional list of peptide lengths, returns a
313458
BindingPredictionCollection.
314459
"""
315-
if isinstance(sequence_dict, str):
316-
sequence_dict = {"seq": sequence_dict}
317-
elif isinstance(sequence_dict, (list, tuple)):
318-
sequence_dict = {seq: seq for seq in sequence_dict}
460+
sequence_dict = _normalize_sequence_dict(sequence_dict)
319461

320462
peptide_lengths = self._check_peptide_lengths(peptide_lengths)
321463

322-
# convert long protein sequences to set of peptides and
323-
# associated sequence name / offsets that each peptide may have come
324-
# from
325-
peptide_set = set([])
326-
peptide_to_name_offset_pairs = defaultdict(list)
327-
328-
for name, sequence in sequence_dict.items():
329-
for peptide_length in peptide_lengths:
330-
for i in range(len(sequence) - peptide_length + 1):
331-
peptide = sequence[i:i + peptide_length]
332-
peptide_set.add(peptide)
333-
peptide_to_name_offset_pairs[peptide].append((name, i))
464+
contexts = _peptide_contexts(sequence_dict, peptide_lengths)
465+
peptide_set = {context.peptide for context in contexts}
466+
peptide_to_contexts = defaultdict(list)
467+
for context in contexts:
468+
peptide_to_contexts[context.peptide].append(context)
334469
peptide_list = sorted(peptide_set)
335470

336471
binding_predictions = self.predict_peptides(peptide_list)
@@ -339,10 +474,10 @@ def predict_subsequences(
339474
results = []
340475
for binding_prediction in binding_predictions:
341476
peptide = binding_prediction.peptide
342-
for name, offset in peptide_to_name_offset_pairs[peptide]:
477+
for context in peptide_to_contexts[peptide]:
343478
results.append(binding_prediction.clone_with_updates(
344-
source_sequence_name=name,
345-
offset=offset))
479+
source_sequence_name=context.source_sequence_name,
480+
offset=context.offset))
346481
self._check_results(
347482
results,
348483
peptides=peptide_set,

0 commit comments

Comments
 (0)