22
33from __future__ import annotations
44
5+ import os
6+ from concurrent .futures import ThreadPoolExecutor
7+
58import numpy as np
69
710_NEWLINE = b"\n "
11+ _MAX_WORKERS = 12
812
913
1014def find_matches (fasta_path : str , pattern : bytes ) -> list [tuple [str , list [int ]]]:
11- """Find every FASTA record whose sequence contains ``pattern``.
12-
13- This version assumes the benchmark-sized generated FASTA input: ASCII
14- headers, DNA sequence lines separated by ``\n ``, and no whitespace inside
15- sequence lines besides those newlines.
16- """
15+ """Find every FASTA record whose sequence contains ``pattern``."""
1716 if not pattern :
1817 return []
1918
@@ -23,28 +22,57 @@ def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]
2322 with open (fasta_path , "rb" ) as file :
2423 data = file .read ()
2524
26- matches : list [tuple [str , list [int ]]] = []
27- for record in data .split (b">" )[1 :]:
28- record_id , _ , wrapped_sequence = record .partition (_NEWLINE )
29- sequence = wrapped_sequence .replace (_NEWLINE , b"" )
30- sequence_len = len (sequence )
31- if sequence_len < pattern_len :
32- continue
33-
34- sequence_values = np .frombuffer (sequence , dtype = np .uint8 )
35- positions_mask = (
36- sequence_values [: sequence_len - pattern_len + 1 ] == pattern_values [0 ]
25+ records = data .split (b">" )[1 :]
26+ worker_count = min (_MAX_WORKERS , os .cpu_count () or 1 , len (records ))
27+ if worker_count <= 1 :
28+ return _scan_records (records , pattern_values , pattern_len )
29+
30+ chunk_size = (len (records ) + worker_count - 1 ) // worker_count
31+ chunks = [
32+ records [start : start + chunk_size ]
33+ for start in range (0 , len (records ), chunk_size )
34+ ]
35+ with ThreadPoolExecutor (max_workers = worker_count ) as executor :
36+ groups = executor .map (
37+ _scan_records ,
38+ chunks ,
39+ [pattern_values ] * len (chunks ),
40+ [pattern_len ] * len (chunks ),
3741 )
38- for pattern_index in range (1 , pattern_len ):
39- positions_mask &= (
40- sequence_values [
41- pattern_index : sequence_len - pattern_len + 1 + pattern_index
42- ]
43- == pattern_values [pattern_index ]
44- )
45-
46- positions = np .nonzero (positions_mask )[0 ]
47- if positions .size :
48- matches .append ((record_id .decode ("ascii" ), positions .tolist ()))
4942
43+ return [match for group in groups for match in group ]
44+
45+
def _scan_records(
    records: list[bytes], pattern_values: np.ndarray, pattern_len: int
) -> list[tuple[str, list[int]]]:
    """Scan a batch of raw records and collect the ones that matched.

    Each record is handed to ``_scan_record``; results that come back as
    ``None`` are dropped, and the remaining results are returned in the same
    order as ``records``.
    """
    outcomes = (
        _scan_record(entry, pattern_values, pattern_len) for entry in records
    )
    return [outcome for outcome in outcomes if outcome is not None]
55+
56+
def _scan_record(
    record: bytes, pattern_values: np.ndarray, pattern_len: int
) -> tuple[str, list[int]] | None:
    """Locate every occurrence of the pattern inside one raw record.

    The record is split at its first ``_NEWLINE`` into an identifier and a
    wrapped sequence; the remaining ``_NEWLINE`` separators are stripped from
    the sequence before it is searched.

    Returns ``(identifier, offsets)`` where the identifier is decoded as ASCII
    and ``offsets`` are 0-based indices into the unwrapped sequence, or
    ``None`` when the sequence is shorter than the pattern or contains no
    occurrence.
    """
    header, _, wrapped = record.partition(_NEWLINE)
    unwrapped = wrapped.replace(_NEWLINE, b"")

    # Number of start offsets at which a full-length window still fits.
    window_count = len(unwrapped) - pattern_len + 1
    if window_count < 1:
        return None

    codes = np.frombuffer(unwrapped, dtype=np.uint8)

    # Start from the offsets whose first byte matches, then narrow the mask
    # one pattern byte at a time using a correspondingly shifted view.
    hits = codes[:window_count] == pattern_values[0]
    for shift in range(1, pattern_len):
        shifted = codes[shift : shift + window_count]
        hits &= shifted == pattern_values[shift]

    offsets = hits.nonzero()[0]
    if not offsets.size:
        return None
    return header.decode("ascii"), offsets.tolist()
0 commit comments