|
5 | 5 | own faster implementation. |
6 | 6 | """ |
7 | 7 |
|
8 | | -from .baseline import find_matches as _baseline |
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +import os |
| 11 | +from concurrent.futures import ThreadPoolExecutor |
| 12 | + |
| 13 | +_NL = 0x0A # b"\n" |
9 | 14 |
|
10 | 15 |
|
def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
    """Find every FASTA record whose sequence contains ``pattern``.

    The file is read fully into memory as bytes. A record starts at a ``>``
    that is either the first byte of the file or directly follows a newline.
    The record ID is the whole header line after ``>``, decoded as ASCII and
    stripped of surrounding whitespace. The sequence is the rest of the
    record with all line terminators (``\\n`` and ``\\r``) removed, so a
    match may straddle line breaks.

    Args:
        fasta_path: Path of the FASTA file to scan.
        pattern: Byte string to search for in each record's sequence.

    Returns:
        ``[(record_id, [positions...]), ...]`` in file order, containing only
        records with at least one match. Positions are 0-based offsets into
        the joined sequence; overlapping matches are all reported.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If a header line contains non-ASCII bytes.
    """
    with open(fasta_path, "rb") as f:
        data = f.read()

    # Step 1: locate every record start. Searching for ``\n>`` finds each
    # line-initial ``>`` directly; offset 0 is handled separately.
    starts: list[int] = [0] if data.startswith(b">") else []
    hit = data.find(b"\n>")
    while hit != -1:
        starts.append(hit + 1)
        hit = data.find(b"\n>", hit + 1)
    if not starts:
        return []
    starts.append(len(data))  # Sentinel marking the end of the last record.
    num_records = len(starts) - 1

    # Step 2: scan in batches. Use several batches per worker so that uneven
    # record sizes still spread reasonably across workers.
    # NOTE(review): this is CPU-bound bytes work; on a GIL build the threads
    # may not actually run in parallel -- measure before relying on it.
    n_workers = max(1, os.cpu_count() or 1)
    batches = n_workers * 4
    batch_size = max(1, (num_records + batches - 1) // batches)

    def scan_batch(start_idx: int, end_idx: int) -> list[tuple[str, list[int]]]:
        """Scan records ``start_idx`` .. ``end_idx - 1`` and return their hits."""
        out: list[tuple[str, list[int]]] = []
        for j in range(start_idx, end_idx):
            rec_start = starts[j]
            rec_end = starts[j + 1]

            # End of the header line within this record's slice.
            nl = data.find(b"\n", rec_start, rec_end)
            if nl == -1:
                continue  # Header-only record without a sequence.

            record_id = data[rec_start + 1 : nl].decode("ascii").strip()

            # Drop both LF and CR in one pass so that Windows (CRLF) files
            # produce the same contiguous sequence as Unix (LF) files.
            sequence = data[nl + 1 : rec_end].translate(None, b"\r\n")

            positions: list[int] = []
            p = sequence.find(pattern)
            while p != -1:
                positions.append(p)
                # Advance by one byte so overlapping matches are reported.
                p = sequence.find(pattern, p + 1)

            if positions:
                out.append((record_id, positions))
        return out

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        futures = [
            pool.submit(scan_batch, lo, min(lo + batch_size, num_records))
            for lo in range(0, num_records, batch_size)
        ]
        # Step 3: results are consumed in submission order, which is file
        # order, so no re-sorting is needed even though batches may finish
        # in any order.
        return [item for fut in futures for item in fut.result()]
0 commit comments