22
33from __future__ import annotations
44
5+ import numpy as np
6+
# Separator used both to cut a record's header line off its sequence and to
# strip line breaks when unwrapping a multi-line sequence.
# NOTE(review): the space after the escape inside this literal looks like an
# extraction artifact (FASTA lines normally end in a bare newline) — confirm
# against the real file before relying on it.
_NEWLINE = b"\n "
68
79
def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
    """Locate every occurrence of *pattern* in each record of a FASTA file.

    The file is read in binary mode in a single call. A record begins at a
    ``>`` that is the first byte of a line; the remainder of that line is
    the record id, and the following lines are joined (``_NEWLINE``
    separators removed) into one sequence before it is searched.
    Overlapping occurrences are all reported.

    Args:
        fasta_path: Path of the FASTA file to scan.
        pattern: Byte string to search for.

    Returns:
        ``(record_id, positions)`` pairs in file order, one for each record
        with at least one hit. ``positions`` lists the ascending 0-based
        offsets of the matches within the unwrapped sequence. An empty
        *pattern* yields ``[]`` without opening the file.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the header line of a matching record is not
            ASCII.
    """
    if not pattern:
        return []

    pattern_values = np.frombuffer(pattern, dtype=np.uint8)
    pattern_len = len(pattern)

    with open(fasta_path, "rb") as file:
        data = file.read()

    # Split only where ">" starts a line. Splitting on every ">" byte would
    # turn a ">" inside a header description (e.g. "A>G") into a bogus
    # record and truncate the real record id.
    leading, *records = data.split(_NEWLINE + b">")
    if leading.startswith(b">"):
        # The first record has no preceding line break to split on.
        records.insert(0, leading[1:])
    # Anything else in `leading` is text before the first record; ignore it.

    matches: list[tuple[str, list[int]]] = []
    for record in records:
        record_id, _, wrapped_sequence = record.partition(_NEWLINE)
        sequence = wrapped_sequence.replace(_NEWLINE, b"")
        sequence_len = len(sequence)
        if sequence_len < pattern_len:
            continue

        sequence_values = np.frombuffer(sequence, dtype=np.uint8)
        # Number of start offsets at which the whole pattern still fits.
        window_count = sequence_len - pattern_len + 1

        # Start from the offsets whose byte equals the first pattern byte,
        # then AND in one shifted comparison per remaining pattern byte; an
        # offset survives only if every pattern byte lines up.
        positions_mask = sequence_values[:window_count] == pattern_values[0]
        for pattern_index in range(1, pattern_len):
            positions_mask &= (
                sequence_values[pattern_index:pattern_index + window_count]
                == pattern_values[pattern_index]
            )

        positions = np.nonzero(positions_mask)[0]
        if positions.size:
            # tolist() converts NumPy integers to plain Python ints.
            matches.append((record_id.decode("ascii"), positions.tolist()))

    return matches
0 commit comments