|
7 | 7 |
|
8 | 8 | from __future__ import annotations |
9 | 9 |
|
10 | | -import functools |
11 | 10 | import os |
12 | 11 | from concurrent.futures import ThreadPoolExecutor |
| 12 | +from mmap import ACCESS_READ, mmap |
| 13 | +from os import fstat |
13 | 14 |
|
14 | | -_DELETE_TABLE = bytes.maketrans(b"", b"") |
15 | | -_DELETE_CHARS = b"\n \r" |
16 | 15 | _NUM_WORKERS = os.cpu_count() or 4 |
17 | 16 |
|
18 | 17 |
|
19 | | -@functools.lru_cache(maxsize=4) |
20 | | -def _load(fasta_path: str) -> bytes: |
21 | | - with open(fasta_path, "rb") as f: |
22 | | - data = f.read() |
23 | | - boundaries = [] |
24 | | - pos = data.find(b">") |
25 | | - while pos != -1: |
26 | | - nxt = data.find(b">", pos + 1) |
27 | | - boundaries.append((pos, nxt if nxt != -1 else len(data))) |
28 | | - pos = nxt |
29 | | - return data, boundaries |
30 | | - |
31 | | - |
def _search_chunk(
    data: bytes,
    pattern: bytes,
    records: list[tuple[int, int]],
) -> list[tuple[str, list[int]]]:
    """Search a batch of FASTA records for every occurrence of ``pattern``.

    Args:
        data: Raw content of the whole FASTA file.
        pattern: Byte string to look for in each record's sequence.
        records: ``(header_start, next_record_start)`` offsets into ``data``,
            where ``header_start`` points at the record's ``>``.

    Returns:
        ``(record_id, positions)`` for each record containing the pattern,
        in the order the records were given. Positions are zero-based
        offsets into the sequence after the bytes in ``_DELETE_CHARS`` have
        been removed; overlapping matches are included.

    Raises:
        UnicodeDecodeError: If the header of a matching record is not ASCII.
    """
    results: list[tuple[str, list[int]]] = []
    for rec_start, rec_end in records:
        # Keep the header search inside this record: a final header without
        # a trailing newline must not raise or spill into a later record.
        nl = data.find(b"\n", rec_start, rec_end)
        if nl == -1:
            nl = rec_end  # header-only record: the sequence is empty
        seq = data[nl + 1 : rec_end].translate(_DELETE_TABLE, _DELETE_CHARS)

        # Step one byte past each hit so overlapping matches are reported.
        positions: list[int] = []
        idx = seq.find(pattern)
        while idx != -1:
            positions.append(idx)
            idx = seq.find(pattern, idx + 1)

        if positions:
            # Decode the header only for records that actually matched.
            record_id = data[rec_start + 1 : nl].strip().decode("ascii")
            results.append((record_id, positions))
    return results
| 18 | +def _scan_record(record: bytes, pattern: bytes) -> tuple[str, list[int]] | None: |
| 19 | + """Scan one FASTA record for all occurrences of ``pattern``. |
61 | 20 |
|
| 21 | + Returns the record id and every zero-based match position, or ``None`` if |
| 22 | + the record is empty or does not contain the pattern. |
| 23 | + """ |
62 | 24 |
|
63 | | -def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]: |
64 | | - """Find every FASTA record whose sequence contains ``pattern``. |
| 25 | + if not record.strip(): |
| 26 | + return None |
65 | 27 |
|
66 | | - Returns ``[(record_id, [positions...]), ...]`` in file order. |
67 | | - """ |
68 | | - data, boundaries = _load(fasta_path) |
| 28 | + # Parition DNA record into header and DNA sequence |
| 29 | + header, _, body = record.partition(b"\n") |
| 30 | + record_id = header.strip().decode("ascii") |
| 31 | + |
| 32 | + # Clean up data before parsing |
| 33 | + sequence = body.replace(b"\n", b"").replace(b"\r", b"").replace(b" ", b"") |
| 34 | + |
| 35 | + positions: list[int] = [] |
| 36 | + start = 0 |
69 | 37 |
|
70 | | - if not boundaries: |
71 | | - return [] |
| 38 | + # Advance by one after each hit so overlapping matches are included. |
| 39 | + while True: |
| 40 | + pos = sequence.find(pattern, start) |
| 41 | + if pos == -1: |
| 42 | + break |
| 43 | + positions.append(pos) |
| 44 | + start = pos + 1 |
72 | 45 |
|
73 | | - n = len(boundaries) |
74 | | - chunk_size = max(1, n // _NUM_WORKERS) |
75 | | - chunks = [boundaries[i : i + chunk_size] for i in range(0, n, chunk_size)] |
| 46 | + if not positions: |
| 47 | + return None |
76 | 48 |
|
77 | | - matches: list[tuple[str, list[int]]] = [] |
78 | | - with ThreadPoolExecutor(max_workers=_NUM_WORKERS) as executor: |
79 | | - futures = [executor.submit(_search_chunk, data, pattern, chunk) for chunk in chunks] |
80 | | - for future in futures: |
81 | | - matches.extend(future.result()) |
| 49 | + return record_id, positions |
82 | 50 |
|
83 | | - return matches |
| 51 | + |
def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
    """Find every FASTA record whose sequence contains ``pattern``.

    The file is memory-mapped and split into records, and each record is
    handed to ``_scan_record`` on a thread pool.

    Args:
        fasta_path: Path of the FASTA file to search.
        pattern: Byte string to look for in each record's sequence.

    Returns:
        ``[(record_id, [positions...]), ...]`` in file order. An empty file
        yields ``[]``.

    Raises:
        OSError: If the file cannot be opened or mapped.
    """
    with open(fasta_path, "rb") as f:
        # mmap cannot map a zero-length file, so handle it up front.
        if fstat(f.fileno()).st_size == 0:
            return []

        with mmap(f.fileno(), 0, access=ACCESS_READ) as text:
            # A record begins only where ">" opens a line.  A ">" inside a
            # header description (e.g. "A->B") must not split the record.
            if text[:1] == b">":
                start = 0
            else:
                newline = text.find(b"\n>")
                start = -1 if newline == -1 else newline + 1

            # Copy each record (without its leading ">") out of the map so
            # the workers never touch the mapping after it is closed.
            records: list[bytes] = []
            while start != -1:
                newline = text.find(b"\n>", start + 1)
                end = len(text) if newline == -1 else newline + 1
                record = text[start + 1 : end]
                if record.strip():
                    records.append(record)
                start = -1 if newline == -1 else end

    with ThreadPoolExecutor(max_workers=_NUM_WORKERS) as executor:
        # map() yields results in submission order, which preserves file order.
        results = list(
            executor.map(lambda record: _scan_record(record, pattern), records)
        )

    return [result for result in results if result is not None]
0 commit comments