11from __future__ import annotations
22
3+ import mmap
34import os
45from concurrent .futures import FIRST_COMPLETED , Future , ThreadPoolExecutor , wait
56from os import PathLike
89
910Pathish = Union [str , bytes , PathLike [str ], PathLike [bytes ]]
1011
11- Record = tuple [int , str , bytearray ]
12+ # (record_index, record_start_offset, record_end_offset)
13+ Span = tuple [int , int , int ]
14+
15+ # (record_index, record_id, match_positions)
1216SearchResult = tuple [int , str , list [int ]]
1317
18+ # Baseline behavior removes spaces and newlines from sequence text.
19+ # In binary mode we also remove '\r' to match text-mode universal newlines.
20+ _DELETE_SEQUENCE_BYTES = b" \r \n "
1421
15- def _iter_fasta_records (fasta_path : Pathish ) -> Iterator [Record ]:
16- """
17- Yield FASTA records as:
1822
19- (record_index, record_id, sequence)
23+ def _default_worker_count () -> int :
24+ # Python 3.13+ may expose process_cpu_count(), which respects CPU limits.
25+ process_cpu_count = getattr (os , "process_cpu_count" , None )
26+
27+ if process_cpu_count is not None :
28+ count = process_cpu_count ()
29+ else :
30+ count = os .cpu_count ()
31+
32+ return count or 1
33+
2034
21- The sequence is accumulated as bytes, with literal spaces removed to match
22- the baseline behavior.
35+ def _iter_record_spans (mm : mmap .mmap , size : int ) -> Iterator [Span ]:
2336 """
37+ Yield FASTA record byte ranges.
2438
25- record_id : str | None = None
26- sequence = bytearray ()
27- index = 0
39+ Assumes valid FASTA-style records where headers begin with '>' at the start
40+ of a line. This is faster than splitting the whole file on b'>'.
41+ """
2842
29- with open (fasta_path , "rb" ) as f :
30- for raw_line in f :
31- if raw_line [:1 ] == b">" :
32- if record_id is not None :
33- yield index , record_id , sequence
34- index += 1
43+ if size == 0 :
44+ return
3545
36- record_id = raw_line [1 :].strip ().decode ("ascii" )
37- sequence = bytearray ()
38- continue
46+ if mm [:1 ] == b">" :
47+ start = 0
48+ else :
49+ marker = mm .find (b"\n >" )
50+ if marker < 0 :
51+ return
52+ start = marker + 1
3953
40- # Ignore preamble before the first FASTA header.
41- if record_id is None :
42- continue
54+ index = 0
4355
44- line = raw_line .rstrip (b"\r \n " )
56+ while start < size :
57+ next_marker = mm .find (b"\n >" , start + 1 )
58+ end = size if next_marker < 0 else next_marker
4559
46- # Match the baseline's `.replace(" ", "")`.
47- if b" " in line :
48- line = line .replace (b" " , b"" )
60+ yield index , start , end
4961
50- sequence . extend ( line )
62+ index += 1
5163
52- if record_id is not None :
53- yield index , record_id , sequence
64+ if next_marker < 0 :
65+ break
5466
67+ start = next_marker + 1
5568
56- def _find_overlapping_positions (sequence : bytearray , pattern : bytes ) -> list [int ]:
69+
70+ def _find_overlapping_positions (sequence : bytes , pattern : bytes ) -> list [int ]:
5771 """
58- Find all overlapping occurrences of pattern in sequence.
72+ Return every overlapping occurrence of pattern in sequence.
5973
6074 Example:
6175 sequence = b"AAAA"
6276 pattern = b"AA"
6377 result = [0, 1, 2]
6478 """
6579
66- # Preserve baseline behavior:
67- # an empty pattern matches every position from 0 through len(sequence).
80+ # Match the baseline's empty-pattern behavior.
6881 if not pattern :
6982 return list (range (len (sequence ) + 1 ))
7083
@@ -76,74 +89,107 @@ def _find_overlapping_positions(sequence: bytearray, pattern: bytes) -> list[int
7689
7790 while True :
7891 pos = find (pattern , start )
92+
7993 if pos < 0 :
8094 return positions
8195
8296 append (pos )
8397 start = pos + 1
8498
8599
def _search_batch(
    mm: mmap.mmap,
    spans: list[Span],
    pattern: bytes,
) -> list[SearchResult]:
    """
    Worker function: search one batch of FASTA records for `pattern`.

    Each worker processes a batch of records rather than a single one, which
    keeps the number of submitted futures small for files with many records.

    Args:
        mm: Buffer holding the FASTA file contents.
        spans: (record_index, start, end) byte ranges, one per record, where
            `start` is the offset of the record's '>' and `end` is exclusive.
        pattern: Byte pattern to look for in each normalized sequence.

    Returns:
        One (record_index, record_id, match_positions) tuple per span, in the
        order the spans were given. Positions are offsets into the sequence
        after spaces and line breaks have been removed.

    Raises:
        UnicodeDecodeError: If a record header is not ASCII.
    """

    results: list[SearchResult] = []

    for index, start, end in spans:
        newline = mm.find(b"\n", start, end)

        # With no newline inside the span, the whole span is the header line.
        header_end = end if newline < 0 else newline

        # Skip the leading '>' and trim surrounding whitespace from the id.
        record_id = mm[start + 1 : header_end].strip().decode("ascii")

        if newline < 0:
            # Header-only record.
            sequence = b""
        else:
            # This does sequence normalization in C:
            # remove line breaks and spaces from the sequence portion.
            sequence = mm[newline + 1 : end].translate(
                None, _DELETE_SEQUENCE_BYTES
            )

        positions = _find_overlapping_positions(sequence, pattern)
        results.append((index, record_id, positions))

    return results
98134
99135
100136def find_matches (
101137 fasta_path : Pathish ,
102138 pattern : bytes ,
103139 * ,
104140 max_workers : int | None = None ,
141+ batch_records : int = 128 ,
142+ batch_bytes : int = 8 << 20 , # 8 MiB
105143 max_pending_batches : int | None = None ,
106- batch_records : int = 64 ,
107- batch_bytes : int = 8 << 20 , # 8 MiB of sequence data
108144) -> list [tuple [str , list [int ]]]:
109145 """
110146 Find every FASTA record whose sequence contains `pattern`.
111147
112148 Returns:
113149 [(record_id, [positions...]), ...]
114150
115- Threaded design :
116- - main thread parses the FASTA file
117- - worker threads search records in parallel
118- - main thread collects results and emits them in original file order
151+ Tuned for roughly :
152+ - 512 MB input
153+ - ~10,145 records
154+ - free-threaded CPython
119155
120- This is designed for free-threaded Python. On normal GIL-enabled CPython,
121- CPU-bound speedup may be much smaller .
156+ The defaults create approximately 60-90 tasks for an input of that size,
157+ rather than 10,145 tiny tasks.
122158 """
123159
124160 pattern = bytes (pattern )
125161
162+ # Preserve the baseline's assumption that the pattern is ASCII text.
163+ pattern .decode ("ascii" )
164+
126165 if max_workers is None :
127- max_workers = os .cpu_count () or 1
166+ max_workers = _default_worker_count ()
167+
128168 if max_workers < 1 :
129169 raise ValueError ("max_workers must be positive" )
130170
131- if max_pending_batches is None :
132- max_pending_batches = max_workers * 2
133- if max_pending_batches < 1 :
134- raise ValueError ("max_pending_batches must be positive" )
135-
136171 if batch_records < 1 :
137172 raise ValueError ("batch_records must be positive" )
173+
138174 if batch_bytes < 1 :
139175 raise ValueError ("batch_bytes must be positive" )
140176
177+ if max_pending_batches is None :
178+ max_pending_batches = max_workers * 4
179+
180+ if max_pending_batches < 1 :
181+ raise ValueError ("max_pending_batches must be positive" )
182+
183+ size = os .path .getsize (fasta_path )
184+
185+ if size == 0 :
186+ return []
187+
141188 matches : list [tuple [str , list [int ]]] = []
142189
143190 # Completed records waiting to be emitted in file order.
144191 ready : dict [int , tuple [str , list [int ]]] = {}
145192
146- pending : set [Future [list [SearchResult ]]] = set ()
147193 next_to_emit = 0
148194
149195 def collect (done : set [Future [list [SearchResult ]]]) -> None :
@@ -153,7 +199,7 @@ def collect(done: set[Future[list[SearchResult]]]) -> None:
153199 for index , record_id , positions in future .result ():
154200 ready [index ] = (record_id , positions )
155201
156- # Emit only when the next file-order record is available .
202+ # Preserve file order even when worker batches complete out of order .
157203 while next_to_emit in ready :
158204 record_id , positions = ready .pop (next_to_emit )
159205
@@ -162,29 +208,42 @@ def collect(done: set[Future[list[SearchResult]]]) -> None:
162208
163209 next_to_emit += 1
164210
165- with ThreadPoolExecutor (max_workers = max_workers ) as executor :
166- batch : list [Record ] = []
167- batch_size = 0
168-
169- for record in _iter_fasta_records (fasta_path ):
170- batch .append (record )
171- batch_size += len (record [2 ])
211+ with open (fasta_path , "rb" ) as file :
212+ with mmap .mmap (file .fileno (), 0 , access = mmap .ACCESS_READ ) as mm :
213+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
214+ pending : set [Future [list [SearchResult ]]] = set ()
172215
173- if len (batch ) >= batch_records or batch_size >= batch_bytes :
174- pending .add (executor .submit (_search_batch , batch , pattern ))
175- batch = []
216+ batch : list [Span ] = []
176217 batch_size = 0
177218
178- # Backpressure: do not let the parser enqueue the whole file.
179- if len (pending ) >= max_pending_batches :
180- done , pending = wait (pending , return_when = FIRST_COMPLETED )
181- collect (done )
219+ for span in _iter_record_spans (mm , size ):
220+ _ , start , end = span
221+
222+ batch .append (span )
223+ batch_size += end - start
182224
183- if batch :
184- pending .add (executor .submit (_search_batch , batch , pattern ))
225+ if len ( batch ) >= batch_records or batch_size >= batch_bytes :
226+ pending .add (executor .submit (_search_batch , mm , batch , pattern ))
185227
186- while pending :
187- done , pending = wait (pending , return_when = FIRST_COMPLETED )
188- collect (done )
228+ batch = []
229+ batch_size = 0
230+
231+ # Backpressure. Avoid queueing unbounded work.
232+ if len (pending ) >= max_pending_batches :
233+ done , pending = wait (
234+ pending ,
235+ return_when = FIRST_COMPLETED ,
236+ )
237+ collect (done )
238+
239+ if batch :
240+ pending .add (executor .submit (_search_batch , mm , batch , pattern ))
241+
242+ while pending :
243+ done , pending = wait (
244+ pending ,
245+ return_when = FIRST_COMPLETED ,
246+ )
247+ collect (done )
189248
190249 return matches
0 commit comments