|
6 | 6 | """ |
7 | 7 |
|
8 | 8 | from .baseline import find_matches as _baseline |
| 9 | +import regex |
| 10 | +from multiprocessing.pool import ThreadPool |
| 11 | + |
| 12 | +def match(record, pattern_str): |
| 13 | + if not record.strip(): |
| 14 | + return None, [] |
| 15 | + |
| 16 | + # split record ID |
| 17 | + lines = record.split("\n") |
| 18 | + record_id = lines[0].strip() |
| 19 | + sequence = "".join(lines[1:]).replace(" ", "") |
| 20 | + |
| 21 | + # regex pattern match, get position if match |
| 22 | + match_inds = [] |
| 23 | + for match in regex.finditer(pattern_str, sequence, overlapped=True): |
| 24 | + match_inds.append(match.start()) |
| 25 | + |
| 26 | + return record_id, match_inds |
9 | 27 |
|
10 | 28 |
|
11 | 29 | def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]: |
12 | 30 | """Find every FASTA record whose sequence contains ``pattern``. |
13 | 31 |
|
14 | 32 | Returns ``[(record_id, [positions...]), ...]`` in file order. |
15 | 33 | """ |
16 | | - # TODO: remove this delegation and write your own implementation here. |
17 | | - return _baseline(fasta_path, pattern) |
| 34 | + pattern_str = pattern.decode("ascii") |
| 35 | + with open(fasta_path, "r") as f: |
| 36 | + text = f.read() |
| 37 | + |
| 38 | + results = [] |
| 39 | + records = text.split(">") |
| 40 | + args = [(record, pattern_str) for record in records] |
| 41 | + |
| 42 | + with ThreadPool(10) as pool: |
| 43 | + |
| 44 | + for record_id, match_inds in pool.starmap(match, args): |
| 45 | + if len(match_inds) > 0: |
| 46 | + # append to results |
| 47 | + results.append((record_id, match_inds)) |
| 48 | + |
| 49 | + return results |
0 commit comments