|
5 | 5 | own faster implementation. |
6 | 6 | """ |
7 | 7 |
|
8 | | -from __future__ import annotations |
| 8 | +from mmap import mmap, ACCESS_READ |
| 9 | +from concurrent.futures import ThreadPoolExecutor, wait |
9 | 10 |
|
10 | | -import os |
11 | | -from concurrent.futures import ThreadPoolExecutor |
| 11 | +def _subsearch(raw, record_id_start: int, data_start: int, data_end: int, pattern: bytes): |
| 12 | + plen = len(pattern) |
| 13 | + data = bytes(raw[data_start : data_end - 1]).replace(b"\n", b"") |
| 14 | + locations = [] |
| 15 | + loc = data.find(pattern) |
| 16 | + while loc != -1: |
| 17 | + locations.append(loc) |
| 18 | + loc = data.find(pattern, loc + plen) |
12 | 19 |
|
13 | | -_NL = 0x0A # b"\n" |
| 20 | + if not locations: |
| 21 | + return None |
14 | 22 |
|
| 23 | + record_id = raw[record_id_start : data_start - 1].decode("ascii") |
| 24 | + return (record_id, locations) |
15 | 25 |
|
16 | 26 | def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]: |
17 | | - with open(fasta_path, "rb") as f: |
18 | | - data = f.read() |
| 27 | + """Find every FASTA record whose sequence contains ``pattern``. |
19 | 28 |
|
20 | | - # Step 1: locate every record start. A record starts with ``>`` either at |
21 | | - # offset 0 or immediately after a ``\n``. |
22 | | - starts: list[int] = [] |
23 | | - i = 0 |
24 | | - while True: |
25 | | - p = data.find(b">", i) |
26 | | - if p == -1: |
27 | | - break |
28 | | - if p == 0 or data[p - 1] == _NL: |
29 | | - starts.append(p) |
30 | | - i = p + 1 |
31 | | - starts.append(len(data)) # sentinel marking the end of the last record. |
| 29 | + Returns ``[(record_id, [positions...]), ...]`` in file order. |
| 30 | + """ |
| 31 | + source = open(fasta_path, "rb") |
| 32 | + data = mmap(source.fileno(), 0, access=ACCESS_READ) |
32 | 33 |
|
33 | | - num_records = len(starts) - 1 |
34 | | - if num_records <= 0: |
35 | | - return [] |
| 34 | + last = -1 |
36 | 35 |
|
37 | | - # Step 2: parallel scan. Choose enough batches to keep workers balanced |
38 | | - # even when record sizes vary. |
39 | | - n_workers = max(1, os.cpu_count() or 1) |
40 | | - batches = max(1, n_workers * 4) |
41 | | - batch_size = max(1, (num_records + batches - 1) // batches) |
| 36 | + data_end = len(data) - 1 |
| 37 | + while data[data_end] == b"\n": |
| 38 | + data_end -= 1 |
42 | 39 |
|
43 | | - def scan_batch(start_idx: int, end_idx: int) -> list[tuple[int, str, list[int]]]: |
44 | | - out: list[tuple[int, str, list[int]]] = [] |
45 | | - for j in range(start_idx, end_idx): |
46 | | - rec_start = starts[j] |
47 | | - rec_end = starts[j + 1] |
| 40 | + with ThreadPoolExecutor(max_workers=16) as executor: |
| 41 | + records = [] |
| 42 | + while data_end > 0: |
| 43 | + gt_pos = data.rfind(b">", 0, data_end) |
| 44 | + if gt_pos == -1: |
| 45 | + raise Exception("expected greater than") |
48 | 46 |
|
49 | | - # Locate the end of the header line within this record's slice. |
50 | | - nl = data.find(b"\n", rec_start, rec_end) |
51 | | - if nl <= rec_start: |
52 | | - continue # Malformed or header-only. |
| 47 | + record_id_start = gt_pos + 1 |
53 | 48 |
|
54 | | - record_id = data[rec_start + 1 : nl].decode("ascii").strip() |
| 49 | + nl_pos = data.find(b"\n", record_id_start) |
| 50 | + if nl_pos == -1: |
| 51 | + raise Exception("expected new line") |
55 | 52 |
|
56 | | - # Contiguous sequence: drop the newlines so matches that straddle |
57 | | - # line breaks are still found by ``bytes.find``. |
58 | | - sequence = data[nl + 1 : rec_end].replace(b"\n", b"") |
| 53 | + data_start = nl_pos + 1 |
59 | 54 |
|
60 | | - positions: list[int] = [] |
61 | | - s = 0 |
62 | | - while True: |
63 | | - p = sequence.find(pattern, s) |
64 | | - if p == -1: |
65 | | - break |
66 | | - positions.append(p) |
67 | | - s = p + 1 |
| 55 | + records.append( |
| 56 | + executor.submit(_subsearch, data, record_id_start, data_start, data_end, pattern) |
| 57 | + ) |
| 58 | + data_end = gt_pos |
68 | 59 |
|
69 | | - if positions: |
70 | | - out.append((j, record_id, positions)) |
71 | | - return out |
72 | | - |
73 | | - with ThreadPoolExecutor(max_workers=n_workers) as pool: |
74 | | - futures = [ |
75 | | - pool.submit(scan_batch, lo, min(lo + batch_size, num_records)) |
76 | | - for lo in range(0, num_records, batch_size) |
77 | | - ] |
78 | | - chunks = [f.result() for f in futures] |
79 | | - |
80 | | - # Step 3: flatten and restore file order (record index is monotonic per |
81 | | - # batch, but batches finish in arbitrary order). |
82 | | - flat = [item for chunk in chunks for item in chunk] |
83 | | - flat.sort(key=lambda triple: triple[0]) |
84 | | - return [(rid, positions) for _, rid, positions in flat] |
| 60 | + results = [d.result() for d in records if d.result() is not None] |
| 61 | + results.reverse() |
| 62 | + return results |
0 commit comments