1212
1313import logging
1414import warnings
15- from collections import defaultdict
15+ from collections import defaultdict , namedtuple
1616
1717from typechecks import require_iterable_of
1818from .allele_normalization import normalize_allele_name
2323
logger = logging.getLogger(__name__)

# One occurrence of a candidate peptide inside a named source sequence.
#   source_sequence_name : key of the sequence the peptide was cut from
#   offset               : 0-based start index of the peptide in that sequence
#   peptide              : the peptide string itself
#   n_flank / c_flank    : residues immediately before / after the peptide
#                          ("" when no flanking context was requested)
PeptideContext = namedtuple(
    "PeptideContext",
    ("source_sequence_name", "offset", "peptide", "n_flank", "c_flank"),
)
29+
30+
def _normalize_sequence_dict(sequence_dict):
    """
    Coerce the accepted protein-input shapes into a name -> sequence mapping.

    * a single string becomes ``{"seq": <string>}``
    * a list or tuple of sequences becomes a dict keyed by the sequences
      themselves (each sequence is its own name)
    * anything else is assumed to already be a mapping and is returned
      unchanged (same object, no copy)
    """
    if isinstance(sequence_dict, str):
        return {"seq": sequence_dict}
    if isinstance(sequence_dict, (list, tuple)):
        # Pair every sequence with itself: name == sequence.
        return dict(zip(sequence_dict, sequence_dict))
    return sequence_dict
37+
38+
def _check_flank_inputs(peptides, n_flanks=None, c_flanks=None):
    """
    Materialize peptides and optional flanks as aligned lists.

    Parameters
    ----------
    peptides : iterable of str
        Peptide sequences. Consumed exactly once, so generators are safe.
    n_flanks, c_flanks : iterable of str, optional
        Flanking sequences, one entry per peptide. ``None`` means "no
        flanks supplied" and is passed through as ``None``.

    Returns
    -------
    tuple
        ``(peptide_list, n_flank_list_or_None, c_flank_list_or_None)``

    Raises
    ------
    TypeError
        If ``peptides``, ``n_flanks`` or ``c_flanks`` is a bare string.
        Iterating a string yields single characters, which would silently
        turn one sequence into many one-residue entries.
    ValueError
        If a flank list does not have exactly one entry per peptide.
    """
    if isinstance(peptides, str):
        # Same guard the flanks get below: list("SIINFEKL") would otherwise
        # produce eight one-letter "peptides".
        raise TypeError(
            "%s must be a sequence of strings, not a string" % "peptides")
    peptide_list = list(peptides)

    def _check_flanks(name, flanks):
        # Validate one flank argument; `name` is only used in messages.
        if flanks is None:
            return None
        if isinstance(flanks, str):
            raise TypeError(
                "%s must be a sequence of strings, not a string" % name)
        flank_list = list(flanks)
        # Element-type check is delegated to the typechecks helper.
        require_iterable_of(flank_list, str, name)
        if len(flank_list) != len(peptide_list):
            raise ValueError(
                "%s must have one entry per peptide, got %d flank(s) for "
                "%d peptide(s)" % (name, len(flank_list), len(peptide_list)))
        return flank_list

    return (
        peptide_list,
        _check_flanks("n_flanks", n_flanks),
        _check_flanks("c_flanks", c_flanks),
    )
60+
61+
def _peptide_contexts(
        sequence_dict,
        peptide_lengths,
        flank_length=0,
        n_flank_length=None,
        c_flank_length=None):
    """
    Enumerate every peptide window of the requested lengths in each sequence.

    Parameters
    ----------
    sequence_dict : dict
        Maps sequence name -> sequence string.
    peptide_lengths : iterable of int
        Window sizes to slide over each sequence. Iterated once per
        sequence, so it must be re-iterable (e.g. a list).
    flank_length : int
        Default number of flanking residues on both sides.
    n_flank_length, c_flank_length : int, optional
        Per-side overrides; ``None`` falls back to ``flank_length``.

    Returns
    -------
    list of PeptideContext
        One entry per (sequence, length, offset) occurrence, in iteration
        order. Flanks are truncated at the sequence ends and are ``""``
        when the corresponding flank length is 0.

    Raises
    ------
    ValueError
        If either effective flank length is negative. A negative C-terminal
        length would otherwise make the slice end index negative and return
        an unrelated span of the sequence instead of an empty flank.
    """
    if n_flank_length is None:
        n_flank_length = flank_length
    if c_flank_length is None:
        c_flank_length = flank_length
    # A falsy length (0, or an explicit None default) means "no flank".
    n_flank_length = n_flank_length or 0
    c_flank_length = c_flank_length or 0
    if n_flank_length < 0 or c_flank_length < 0:
        raise ValueError(
            "flank lengths must be non-negative, got n_flank_length=%r, "
            "c_flank_length=%r" % (n_flank_length, c_flank_length))

    contexts = []
    for name, sequence in sequence_dict.items():
        for peptide_length in peptide_lengths:
            # range() is empty when the sequence is shorter than the window.
            for start in range(len(sequence) - peptide_length + 1):
                end = start + peptide_length
                contexts.append(PeptideContext(
                    source_sequence_name=name,
                    offset=start,
                    peptide=sequence[start:end],
                    # max(0, ...) stops the start index wrapping around to
                    # the end of the sequence near the N-terminus.
                    n_flank=sequence[max(0, start - n_flank_length):start],
                    c_flank=sequence[end:end + c_flank_length],
                ))
    return contexts
94+
95+
2696class BasePredictor (object ):
2797 """
2898 Base class for all MHC binding predictors.
2999 """
# Flank-handling defaults, read by _predict_protein_flank_lengths().
# When uses_flanking_sequences is False no flanks are extracted at all.
uses_flanking_sequences = False
# Number of flanking residues used on a side whose specific length is None.
flank_length = 15
# Per-side overrides; None means "fall back to flank_length".
n_flank_length = c_flank_length = None
104+
30105 def __init__ (
31106 self ,
32107 alleles ,
@@ -90,26 +165,85 @@ def __str__(self):
90165
91166 # --- new API ---
92167
def predict(self, peptides, n_flanks=None, c_flanks=None):
    """
    Predict for a list of peptide sequences.

    ``n_flanks`` and ``c_flanks`` are part of the signature so that every
    predictor exposes the same flank-aware call shape. This default,
    BindingPrediction-based implementation checks them (type and
    one-entry-per-peptide alignment) and then discards them. Predictors
    that really use flanking context should override this method and set
    ``uses_flanking_sequences = True``.

    Returns
    -------
    list of PeptideResult
    """
    # Only the materialized peptide list is needed; the flanks are
    # validated for their side effect of raising on bad input.
    peptide_list = _check_flank_inputs(peptides, n_flanks, c_flanks)[0]
    binding_predictions = self.predict_peptides(peptide_list)
    pred_kind = self._default_pred_kind()
    return binding_predictions.to_peptide_preds(kind=pred_kind)
103185
104- def predict_dataframe (self , peptides , sample_name = "" ):
def predict_with_flanks(self, peptides, n_flanks, c_flanks):
    """
    Optional flank-aware prediction path.

    The flank lists are validated (one entry per peptide) through
    ``self._check_flank_inputs`` and the call is then delegated to
    ``predict``:

    * if ``uses_flanking_sequences`` is true, the validated flanks are
      forwarded as ``n_flanks=`` / ``c_flanks=``, so a subclass that only
      overrides ``predict`` still receives them;
    * otherwise only the peptides are passed, so predictors that do not
      use flanking context (including ones whose ``predict`` takes just
      ``peptides``) keep their existing behavior.

    Subclasses whose models need a different entry point can still
    override this method directly.

    Returns
    -------
    Whatever ``predict`` returns for the given peptides.
    """
    peptide_list, n_flank_list, c_flank_list = self._check_flank_inputs(
        peptides, n_flanks, c_flanks)
    if self.uses_flanking_sequences:
        # The predictor opted in to flanks: do not silently drop them.
        return self.predict(
            peptide_list, n_flanks=n_flank_list, c_flanks=c_flank_list)
    return self.predict(peptide_list)
198+
def predict_dataframe(
        self, peptides, sample_name="", n_flanks=None, c_flanks=None):
    """
    predict() flattened to a DataFrame.

    Each result returned by ``predict`` is converted with
    ``to_dataframe(sample_name)`` and the pieces are concatenated with a
    fresh 0..n-1 index. When there are no results, an empty DataFrame with
    the ``COLUMNS`` from ``.pred`` is returned instead.
    """
    import pandas as pd

    peptide_results = self.predict(
        peptides, n_flanks=n_flanks, c_flanks=c_flanks)
    frames = []
    for peptide_result in peptide_results:
        frames.append(peptide_result.to_dataframe(sample_name))

    if frames:
        return pd.concat(frames, ignore_index=True)

    # Nothing to concatenate: return an empty frame with the right columns.
    from .pred import COLUMNS
    return pd.DataFrame(columns=COLUMNS)
112212
def _predict_protein_flank_length(self):
    """
    Return the larger of the N- and C-terminal flank lengths reported by
    ``_predict_protein_flank_lengths``.
    """
    flank_lengths = self._predict_protein_flank_lengths()
    return max(flank_lengths)
215+
def _predict_protein_flank_lengths(self):
    """
    Return ``(n_flank_length, c_flank_length)`` for protein scanning.

    Predictors with ``uses_flanking_sequences`` unset get ``(0, 0)``.
    Otherwise each side uses its own ``n_flank_length`` /
    ``c_flank_length`` attribute, falling back to the shared
    ``flank_length`` when that attribute is ``None``.
    """
    if self.uses_flanking_sequences:
        n_side = self.n_flank_length
        if n_side is None:
            n_side = self.flank_length
        c_side = self.c_flank_length
        if c_side is None:
            c_side = self.flank_length
        return (n_side, c_side)
    return (0, 0)
226+
def _check_flank_inputs(self, peptides, n_flanks=None, c_flanks=None):
    """
    Instance-level entry point for flank validation.

    Delegates unchanged to the module-level ``_check_flank_inputs`` helper
    and returns its ``(peptide_list, n_flanks, c_flanks)`` result.
    """
    # Inside a method body the bare name resolves to the module-level
    # function, not to this method, so this is not a recursive call.
    checked = _check_flank_inputs(peptides, n_flanks, c_flanks)
    return checked
229+
@staticmethod
def _with_protein_location(pred, context):
    """
    Build a new ``Prediction`` that copies ``pred`` but takes its location
    from ``context``.

    The source sequence name, offset and both flanks come from ``context``.
    Every other field (kind, score, peptide, allele, value, percentile
    rank, predictor name and version) is carried over from ``pred``.
    """
    carried_over = {
        field: getattr(pred, field)
        for field in (
            "kind",
            "score",
            "peptide",
            "allele",
            "value",
            "percentile_rank",
            "predictor_name",
            "predictor_version",
        )
    }
    return Prediction(
        n_flank=context.n_flank,
        c_flank=context.c_flank,
        source_sequence_name=context.source_sequence_name,
        offset=context.offset,
        **carried_over
    )
246+
113247 def predict_proteins (self , sequence_dict , peptide_lengths = None ):
114248 """
115249 Scan protein sequences and predict for all subsequences.
@@ -126,23 +260,44 @@ def predict_proteins(self, sequence_dict, peptide_lengths=None):
126260 -------
127261 dict mapping sequence_name -> list of PeptideResult
128262 """
129- if isinstance (sequence_dict , str ):
130- sequence_dict = {"seq" : sequence_dict }
131- elif isinstance (sequence_dict , (list , tuple )):
132- sequence_dict = {seq : seq for seq in sequence_dict }
263+ sequence_dict = _normalize_sequence_dict (sequence_dict )
133264
134265 peptide_lengths = self ._check_peptide_lengths (peptide_lengths )
135266
136- peptide_set = set ()
137- peptide_to_name_offset_pairs = defaultdict (list )
138-
139- for name , sequence in sequence_dict .items ():
140- for peptide_length in peptide_lengths :
141- for i in range (len (sequence ) - peptide_length + 1 ):
142- peptide = sequence [i :i + peptide_length ]
143- peptide_set .add (peptide )
144- peptide_to_name_offset_pairs [peptide ].append ((name , i ))
267+ n_flank_length , c_flank_length = self ._predict_protein_flank_lengths ()
268+ contexts = _peptide_contexts (
269+ sequence_dict ,
270+ peptide_lengths ,
271+ n_flank_length = n_flank_length ,
272+ c_flank_length = c_flank_length )
273+
274+ if self .uses_flanking_sequences :
275+ peptide_list = [context .peptide for context in contexts ]
276+ n_flanks = [context .n_flank for context in contexts ]
277+ c_flanks = [context .c_flank for context in contexts ]
278+ flat_preds = self .predict_with_flanks (
279+ peptide_list ,
280+ n_flanks = n_flanks ,
281+ c_flanks = c_flanks )
282+ if len (flat_preds ) != len (contexts ):
283+ raise ValueError (
284+ "%s.predict returned %d result(s) for %d flanked "
285+ "peptide occurrence(s)" % (
286+ self .__class__ .__name__ , len (flat_preds ),
287+ len (contexts )))
288+ results = defaultdict (list )
289+ for context , pp in zip (contexts , flat_preds ):
290+ relocated = PeptideResult (preds = tuple (
291+ self ._with_protein_location (p , context )
292+ for p in pp .preds
293+ ))
294+ results [context .source_sequence_name ].append (relocated )
295+ return dict (results )
145296
297+ peptide_set = {context .peptide for context in contexts }
298+ peptide_to_contexts = defaultdict (list )
299+ for context in contexts :
300+ peptide_to_contexts [context .peptide ].append (context )
146301 peptide_list = sorted (peptide_set )
147302 flat_preds = self .predict (peptide_list )
148303
@@ -152,22 +307,12 @@ def predict_proteins(self, sequence_dict, peptide_lengths=None):
152307 if not pp .preds :
153308 continue
154309 peptide = pp .preds [0 ].peptide
155- for name , offset in peptide_to_name_offset_pairs .get (peptide , []):
310+ for context in peptide_to_contexts .get (peptide , []):
156311 relocated = PeptideResult (preds = tuple (
157- Prediction (
158- kind = p .kind ,
159- score = p .score ,
160- peptide = p .peptide ,
161- allele = p .allele ,
162- value = p .value ,
163- percentile_rank = p .percentile_rank ,
164- source_sequence_name = name ,
165- offset = offset ,
166- predictor_name = p .predictor_name ,
167- predictor_version = p .predictor_version ,
168- ) for p in pp .preds
312+ self ._with_protein_location (p , context )
313+ for p in pp .preds
169314 ))
170- results [name ].append (relocated )
315+ results [context . source_sequence_name ].append (relocated )
171316 return dict (results )
172317
173318 def predict_proteins_dataframe (self , sequence_dict , peptide_lengths = None , sample_name = "" ):
@@ -312,25 +457,15 @@ def predict_subsequences(
312457 and an optional list of peptide lengths, returns a
313458 BindingPredictionCollection.
314459 """
315- if isinstance (sequence_dict , str ):
316- sequence_dict = {"seq" : sequence_dict }
317- elif isinstance (sequence_dict , (list , tuple )):
318- sequence_dict = {seq : seq for seq in sequence_dict }
460+ sequence_dict = _normalize_sequence_dict (sequence_dict )
319461
320462 peptide_lengths = self ._check_peptide_lengths (peptide_lengths )
321463
322- # convert long protein sequences to set of peptides and
323- # associated sequence name / offsets that each peptide may have come
324- # from
325- peptide_set = set ([])
326- peptide_to_name_offset_pairs = defaultdict (list )
327-
328- for name , sequence in sequence_dict .items ():
329- for peptide_length in peptide_lengths :
330- for i in range (len (sequence ) - peptide_length + 1 ):
331- peptide = sequence [i :i + peptide_length ]
332- peptide_set .add (peptide )
333- peptide_to_name_offset_pairs [peptide ].append ((name , i ))
464+ contexts = _peptide_contexts (sequence_dict , peptide_lengths )
465+ peptide_set = {context .peptide for context in contexts }
466+ peptide_to_contexts = defaultdict (list )
467+ for context in contexts :
468+ peptide_to_contexts [context .peptide ].append (context )
334469 peptide_list = sorted (peptide_set )
335470
336471 binding_predictions = self .predict_peptides (peptide_list )
@@ -339,10 +474,10 @@ def predict_subsequences(
339474 results = []
340475 for binding_prediction in binding_predictions :
341476 peptide = binding_prediction .peptide
342- for name , offset in peptide_to_name_offset_pairs [peptide ]:
477+ for context in peptide_to_contexts [peptide ]:
343478 results .append (binding_prediction .clone_with_updates (
344- source_sequence_name = name ,
345- offset = offset ))
479+ source_sequence_name = context . source_sequence_name ,
480+ offset = context . offset ))
346481 self ._check_results (
347482 results ,
348483 peptides = peptide_set ,
0 commit comments