|
5 | 5 | own faster implementation. |
6 | 6 | """ |
7 | 7 |
|
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +import os |
8 | 11 | from concurrent.futures import ThreadPoolExecutor |
9 | | -from threading import Thread |
10 | | - |
11 | | - |
12 | | -def find_matches_in_sequence( |
13 | | - record_id: str, |
14 | | - sequence: str, |
15 | | - pattern_str: str, |
16 | | - matches: list[tuple[str, list[int]]], |
17 | | -): |
18 | | - """Find matches in a single sequence and append to the shared matches list.""" |
19 | | - positions: list[int] = [] |
20 | | - start = 0 |
21 | | - while True: |
22 | | - pos = sequence.find(pattern_str, start) |
23 | | - if pos == -1: |
24 | | - break |
25 | | - positions.append(pos) |
26 | | - start = pos + 1 |
27 | | - |
28 | | - if positions: |
29 | | - matches.append((record_id, positions)) |
30 | | - |
31 | | - |
32 | | -def find_matches_many_threads( |
33 | | - fasta_path: str, pattern: bytes |
34 | | -) -> list[tuple[str, list[int]]]: |
35 | | - # Step 1: read the whole FASTA file as text and decode the pattern so the |
36 | | - # search below can use a single ``str`` API. |
37 | | - pattern_str = pattern.decode("ascii") |
38 | | - with open(fasta_path, "r") as f: |
39 | | - text = f.read() |
40 | | - |
41 | | - matches: list[tuple[str, list[int]]] = [] |
42 | | - |
43 | | - # Preprocess the sequences |
44 | | - sequences = [] |
45 | | - for record in text.split(">"): |
46 | | - if not record.strip(): |
47 | | - continue |
48 | | - lines = record.split("\n") |
49 | | - record_id = lines[0].strip() |
50 | | - sequence = "".join(lines[1:]).replace(" ", "") |
51 | | - sequences.append((record_id, sequence)) |
52 | | - threads = [] |
53 | | - for record_id, sequence in sequences: |
54 | | - thread = Thread( |
55 | | - target=find_matches_in_sequence, |
56 | | - args=(record_id, sequence, pattern_str, matches), |
57 | | - ) |
58 | | - thread.start() |
59 | | - threads.append(thread) |
60 | | - # Wait for all threads to finish |
61 | | - print(f"Waiting for {len(threads)} threads to finish...") |
62 | | - for thread in threads: |
63 | | - thread.join() |
64 | | - |
65 | | - return matches |
| 12 | + |
| 13 | +_NL = 0x0A # b"\n" |
66 | 14 |
|
67 | 15 |
|
68 | 16 | def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]: |
69 | | - # Step 1: read the whole FASTA file as text and decode the pattern so the |
70 | | - # search below can use a single ``str`` API. |
71 | | - pattern_str = pattern.decode("ascii") |
72 | | - with open(fasta_path, "r") as f: |
73 | | - text = f.read() |
74 | | - |
75 | | - matches: list[tuple[str, list[int]]] = [] |
76 | | - |
77 | | - # Preprocess the sequences |
78 | | - sequences = [] |
79 | | - for record in text.split(">"): |
80 | | - if not record.strip(): |
81 | | - continue |
82 | | - lines = record.split("\n") |
83 | | - record_id = lines[0].strip() |
84 | | - sequence = "".join(lines[1:]).replace(" ", "") |
85 | | - sequences.append((record_id, sequence)) |
86 | | - |
87 | | - # Create a pool of threads |
88 | | - pool = ThreadPoolExecutor(max_workers=16) |
89 | | - |
90 | | - for record_id, sequence in sequences: |
91 | | - pool.submit( |
92 | | - find_matches_in_sequence, |
93 | | - record_id, |
94 | | - sequence, |
95 | | - pattern_str, |
96 | | - matches, |
97 | | - ) |
98 | | - # Or |
99 | | - # pool.map( |
100 | | - # lambda args: find_matches_in_sequence(*args), |
101 | | - # [ |
102 | | - # (record_id, sequence, pattern_str, matches) |
103 | | - # for record_id, sequence in sequences |
104 | | - # ], |
105 | | - # ) |
106 | | - # Wait for all threads to finish |
107 | | - pool.shutdown(wait=True) |
108 | | - |
109 | | - return matches |
| 17 | + with open(fasta_path, "rb") as f: |
| 18 | + data = f.read() |
| 19 | + |
| 20 | + # Step 1: locate every record start. A record starts with ``>`` either at |
| 21 | + # offset 0 or immediately after a ``\n``. |
| 22 | + starts: list[int] = [] |
| 23 | + i = 0 |
| 24 | + while True: |
| 25 | + p = data.find(b">", i) |
| 26 | + if p == -1: |
| 27 | + break |
| 28 | + if p == 0 or data[p - 1] == _NL: |
| 29 | + starts.append(p) |
| 30 | + i = p + 1 |
| 31 | + starts.append(len(data)) # sentinel marking the end of the last record. |
| 32 | + |
| 33 | + num_records = len(starts) - 1 |
| 34 | + if num_records <= 0: |
| 35 | + return [] |
| 36 | + |
| 37 | + # Step 2: parallel scan. Choose enough batches to keep workers balanced |
| 38 | + # even when record sizes vary. |
| 39 | + n_workers = max(1, os.cpu_count() or 1) |
| 40 | + batches = max(1, n_workers * 4) |
| 41 | + batch_size = max(1, (num_records + batches - 1) // batches) |
| 42 | + |
| 43 | + def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]: |
| 44 | + out: list[tuple[int, str, list[int]]] = [] |
| 45 | + for j in range(start_idx, end_idx): |
| 46 | + rec_start = starts[j] |
| 47 | + rec_end = starts[j + 1] |
| 48 | + |
| 49 | + # Locate the end of the header line within this record's slice. |
| 50 | + nl = data.find(b"\n", rec_start, rec_end) |
| 51 | + if nl <= rec_start: |
| 52 | + continue # Malformed or header-only. |
| 53 | + |
| 54 | + record_id = data[rec_start + 1 : nl].decode("ascii").strip() |
| 55 | + |
| 56 | + # Contiguous sequence: drop the newlines so matches that straddle |
| 57 | + # line breaks are still found by ``bytes.find``. |
| 58 | + sequence = data[nl + 1 : rec_end].replace(b"\n", b"") |
| 59 | + |
| 60 | + positions: list[int] = [] |
| 61 | + s = 0 |
| 62 | + while True: |
| 63 | + p = sequence.find(pattern, s) |
| 64 | + if p == -1: |
| 65 | + break |
| 66 | + positions.append(p) |
| 67 | + s = p + 1 |
| 68 | + |
| 69 | + if positions: |
| 70 | + out.append((j, record_id, positions)) |
| 71 | + return out |
| 72 | + |
| 73 | + with ThreadPoolExecutor(max_workers=n_workers) as pool: |
| 74 | + futures = [ |
| 75 | + pool.submit(scan_batch, lo, min(lo + batch_size, num_records)) |
| 76 | + for lo in range(0, num_records, batch_size) |
| 77 | + ] |
| 78 | + chunks = [f.result() for f in futures] |
| 79 | + |
| 80 | + # Step 3: flatten and restore file order (record index is monotonic per |
| 81 | + # batch, but batches finish in arbitrary order). |
| 82 | + flat = [item for chunk in chunks for item in chunk] |
| 83 | + flat.sort(key=lambda triple: triple[0]) |
| 84 | + return [(rid, positions) for _, rid, positions in flat] |
0 commit comments