|
5 | 5 | own faster implementation. |
6 | 6 | """ |
7 | 7 |
|
8 | | -from mmap import mmap, ACCESS_READ |
9 | | -from concurrent.futures import ThreadPoolExecutor, wait |
| 8 | +from __future__ import annotations |
10 | 9 |
|
11 | | -def _subsearch(raw, record_id_start: int, data_start: int, data_end: int, pattern: bytes): |
12 | | - plen = len(pattern) |
13 | | - data = bytes(raw[data_start : data_end - 1]).replace(b"\n", b"") |
14 | | - locations = [] |
15 | | - loc = data.find(pattern) |
16 | | - while loc != -1: |
17 | | - locations.append(loc) |
18 | | - loc = data.find(pattern, loc + plen) |
| 10 | +import os |
| 11 | +from concurrent.futures import ThreadPoolExecutor |
19 | 12 |
|
20 | | - if not locations: |
21 | | - return None |
| 13 | +_NL = 0x0A |
22 | 14 |
|
23 | | - record_id = raw[record_id_start : data_start - 1].decode("ascii") |
24 | | - return (record_id, locations) |
25 | 15 |
|
def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
    """Find every FASTA record whose sequence contains ``pattern``.

    The whole file is read into memory. A record starts at a ``>`` that is
    the first byte of the file or directly follows a newline; any bytes
    before the first such ``>`` are ignored. The record id is the full
    header line (without the ``>``), decoded as ASCII and stripped of
    surrounding whitespace. Records with no newline after the header
    (header only, no sequence) are skipped.

    Args:
        fasta_path: Path of the FASTA file to scan.
        pattern: Byte string to search for in each record's sequence.

    Returns:
        ``[(record_id, [positions...]), ...]`` in file order, listing only
        records with at least one match. Positions are 0-based offsets into
        the sequence with line breaks removed; overlapping matches are all
        reported.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the header of a matching record is not ASCII.
    """
    with open(fasta_path, "rb") as f:
        data = f.read()

    # Step 1: locate every record start ('>' at offset 0 or right after a
    # newline). Searching for the two-byte marker skips '>' bytes that occur
    # in the middle of a line.
    starts: list[int] = [0] if data.startswith(b">") else []
    marker = data.find(b"\n>")
    while marker != -1:
        starts.append(marker + 1)
        marker = data.find(b"\n>", marker + 1)
    if not starts:
        return []
    num_records = len(starts)
    starts.append(len(data))  # sentinel: end of the last record

    # Step 2: parallel scan with batched work units (about four batches per
    # worker, so one slow batch does not leave the other workers idle).
    n_workers = max(1, os.cpu_count() or 1)
    batches = n_workers * 4
    batch_size = max(1, (num_records + batches - 1) // batches)

    def scan_batch(start_idx: int, end_idx: int) -> list[tuple[str, list[int]]]:
        """Scan records ``start_idx`` to ``end_idx - 1``; return hits in order."""
        out: list[tuple[str, list[int]]] = []
        for j in range(start_idx, end_idx):
            rec_start = starts[j]
            rec_end = starts[j + 1]

            # data[rec_start] is '>', so a newline found here always lies
            # after it; -1 means the record is a bare header line.
            nl = data.find(b"\n", rec_start, rec_end)
            if nl == -1:
                continue

            # Drop line breaks so matches spanning lines are found. '\r' is
            # removed too: in CRLF files it would otherwise stay inside the
            # sequence, hiding cross-line matches and shifting offsets.
            sequence = data[nl + 1 : rec_end].translate(None, b"\r\n")

            positions: list[int] = []
            hit = sequence.find(pattern)
            while hit != -1:
                positions.append(hit)
                # Advance by one byte so overlapping matches are reported.
                hit = sequence.find(pattern, hit + 1)

            if positions:
                record_id = data[rec_start + 1 : nl].decode("ascii").strip()
                out.append((record_id, positions))
        return out

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        futures = [
            pool.submit(scan_batch, lo, min(lo + batch_size, num_records))
            for lo in range(0, num_records, batch_size)
        ]
        # Results are collected in submission order and each batch is
        # internally ordered, so file order is preserved without sorting.
        chunks = [future.result() for future in futures]

    # Step 3: flatten the per-batch results.
    return [item for chunk in chunks for item in chunk]
0 commit comments