@@ -16,16 +16,17 @@ def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]
1616 if not pattern :
1717 return []
1818
19- pattern_values = np .frombuffer (pattern , dtype = np .uint8 )
2019 pattern_len = len (pattern )
20+ pattern_prefix = np .frombuffer (pattern [:4 ], dtype = np .uint32 )[0 ]
21+ pattern_suffix = np .frombuffer (pattern [4 :], dtype = np .uint32 )[0 ]
2122
2223 with open (fasta_path , "rb" ) as file :
2324 data = file .read ()
2425
2526 records = data .split (b">" )[1 :]
2627 worker_count = min (_MAX_WORKERS , os .cpu_count () or 1 , len (records ))
2728 if worker_count <= 1 :
28- return _scan_records (records , pattern_values , pattern_len )
29+ return _scan_records (records , pattern_prefix , pattern_suffix , pattern_len )
2930
3031 chunk_size = (len (records ) + worker_count - 1 ) // worker_count
3132 chunks = [
@@ -36,43 +37,58 @@ def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]
3637 groups = executor .map (
3738 _scan_records ,
3839 chunks ,
39- [pattern_values ] * len (chunks ),
40+ [pattern_prefix ] * len (chunks ),
41+ [pattern_suffix ] * len (chunks ),
4042 [pattern_len ] * len (chunks ),
4143 )
4244
4345 return [match for group in groups for match in group ]
4446
4547
def _scan_records(
    records: list[bytes],
    pattern_prefix: np.uint32,
    pattern_suffix: np.uint32,
    pattern_len: int,
) -> list[tuple[str, list[int]]]:
    """Scan a batch of FASTA records and collect those containing the pattern.

    Each record is passed to ``_scan_record`` together with the pattern
    description; records for which it returns ``None`` are dropped.

    Args:
        records: Raw record byte strings, one per FASTA entry.
        pattern_prefix: Pattern word compared at each candidate offset.
        pattern_suffix: Pattern word compared four bytes after the prefix.
        pattern_len: Length of the search pattern in bytes.

    Returns:
        ``(record_id, positions)`` tuples, in the same order as ``records``,
        for every record with at least one match.
    """
    scanned = (
        _scan_record(record, pattern_prefix, pattern_suffix, pattern_len)
        for record in records
    )
    return [match for match in scanned if match is not None]
5560
5661
def _scan_record(
    record: bytes,
    pattern_prefix: np.uint32,
    pattern_suffix: np.uint32,
    pattern_len: int,
) -> tuple[str, list[int]] | None:
    """Find every occurrence of an 8-byte pattern in one FASTA record.

    The record is split at the first ``_NEWLINE`` into an identifier line and
    the wrapped sequence; remaining ``_NEWLINE`` separators are removed before
    searching.  The pattern is compared as two native-endian 32-bit words:
    ``pattern_prefix`` against bytes ``[i, i + 4)`` and ``pattern_suffix``
    against bytes ``[i + 4, i + 8)`` of the unwrapped sequence.

    Args:
        record: One raw record (identifier line followed by sequence lines).
        pattern_prefix: First four pattern bytes viewed as ``np.uint32``.
        pattern_suffix: Last four pattern bytes viewed as ``np.uint32``.
        pattern_len: Pattern length in bytes; must be 8.

    Returns:
        ``(record_id, positions)`` where ``positions`` are 0-based offsets
        into the unwrapped sequence, or ``None`` when there is no match.

    Raises:
        ValueError: If ``pattern_len`` is not 8.  The two words cover exactly
            eight bytes, so a longer pattern would yield false positives and
            a shorter one would read past the end of the sequence buffer.
        UnicodeDecodeError: If the identifier line is not ASCII.
    """
    word_size = 4
    supported_len = 2 * word_size
    if pattern_len != supported_len:
        raise ValueError(
            f"pattern_len must be {supported_len} for prefix/suffix word "
            f"matching, got {pattern_len}"
        )

    record_id, _, wrapped_sequence = record.partition(_NEWLINE)
    sequence = wrapped_sequence.replace(_NEWLINE, b"")
    sequence_len = len(sequence)
    if sequence_len < pattern_len:
        return None

    candidate_count = sequence_len - pattern_len + 1

    # Overlapping 4-byte windows with a 1-byte stride: element i is the word
    # starting at sequence[i].  The array is a read-only view of the bytes.
    prefixes = np.ndarray(
        shape=(candidate_count,),
        dtype=np.uint32,
        buffer=sequence,
        strides=(1,),
    )
    candidates = np.nonzero(prefixes == pattern_prefix)[0]
    if not candidates.size:
        return None

    # Same windowing shifted by one word: element i starts at sequence[i + 4].
    suffixes = np.ndarray(
        shape=(candidate_count,),
        dtype=np.uint32,
        buffer=sequence,
        offset=word_size,
        strides=(1,),
    )
    # Only prefix hits are checked against the suffix word.
    positions = candidates[suffixes[candidates] == pattern_suffix]
    if positions.size:
        return record_id.decode("ascii"), positions.tolist()
    return None
0 commit comments