Skip to content

Commit dde5cca

Browse files
committed
optimize byte code perf
1 parent 01d0c41 commit dde5cca

1 file changed

Lines changed: 58 additions & 61 deletions

File tree

rounds/3_dna/solution.py

Lines changed: 58 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,77 +7,74 @@
77

88
from __future__ import annotations
99

10-
import functools
1110
import os
1211
from concurrent.futures import ThreadPoolExecutor
12+
from mmap import ACCESS_READ, mmap
13+
from os import fstat
1314

14-
_DELETE_TABLE = bytes.maketrans(b"", b"")
15-
_DELETE_CHARS = b"\n \r"
1615
_NUM_WORKERS = os.cpu_count() or 4
1716

1817

19-
@functools.lru_cache(maxsize=4)
20-
def _load(fasta_path: str) -> bytes:
21-
with open(fasta_path, "rb") as f:
22-
data = f.read()
23-
boundaries = []
24-
pos = data.find(b">")
25-
while pos != -1:
26-
nxt = data.find(b">", pos + 1)
27-
boundaries.append((pos, nxt if nxt != -1 else len(data)))
28-
pos = nxt
29-
return data, boundaries
30-
31-
32-
def _search_chunk(
33-
data: bytes,
34-
pattern: bytes,
35-
records: list[tuple[int, int]],
36-
) -> list[tuple[str, list[int]]]:
37-
"""Process a batch of (header_start, next_record_start) pairs."""
38-
results: list[tuple[str, list[int]]] = []
39-
for rec_start, rec_end in records:
40-
nl = data.index(b"\n", rec_start)
41-
seq = data[nl + 1 : rec_end].translate(_DELETE_TABLE, _DELETE_CHARS)
42-
43-
if pattern not in seq:
44-
continue
45-
46-
record_id = data[rec_start + 1 : nl].strip().decode("ascii")
47-
48-
positions: list[int] = []
49-
start = 0
50-
_find = seq.find
51-
while True:
52-
idx = _find(pattern, start)
53-
if idx == -1:
54-
break
55-
positions.append(idx)
56-
start = idx + 1
57-
58-
if positions:
59-
results.append((record_id, positions))
60-
return results
18+
def _scan_record(record: bytes, pattern: bytes) -> tuple[str, list[int]] | None:
19+
"""Scan one FASTA record for all occurrences of ``pattern``.
6120
21+
Returns the record id and every zero-based match position, or ``None`` if
22+
the record is empty or does not contain the pattern.
23+
"""
6224

63-
def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
64-
"""Find every FASTA record whose sequence contains ``pattern``.
25+
if not record.strip():
26+
return None
6527

66-
Returns ``[(record_id, [positions...]), ...]`` in file order.
67-
"""
68-
data, boundaries = _load(fasta_path)
28+
# Partition DNA record into header and DNA sequence
29+
header, _, body = record.partition(b"\n")
30+
record_id = header.strip().decode("ascii")
31+
32+
# Clean up data before parsing
33+
sequence = body.replace(b"\n", b"").replace(b"\r", b"").replace(b" ", b"")
34+
35+
positions: list[int] = []
36+
start = 0
6937

70-
if not boundaries:
71-
return []
38+
# Advance by one after each hit so overlapping matches are included.
39+
while True:
40+
pos = sequence.find(pattern, start)
41+
if pos == -1:
42+
break
43+
positions.append(pos)
44+
start = pos + 1
7245

73-
n = len(boundaries)
74-
chunk_size = max(1, n // _NUM_WORKERS)
75-
chunks = [boundaries[i : i + chunk_size] for i in range(0, n, chunk_size)]
46+
if not positions:
47+
return None
7648

77-
matches: list[tuple[str, list[int]]] = []
78-
with ThreadPoolExecutor(max_workers=_NUM_WORKERS) as executor:
79-
futures = [executor.submit(_search_chunk, data, pattern, chunk) for chunk in chunks]
80-
for future in futures:
81-
matches.extend(future.result())
49+
return record_id, positions
8250

83-
return matches
51+
52+
def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
    """Find every FASTA record whose sequence contains ``pattern``.

    A record begins at a ``>`` in the first column of a line. A ``>`` that
    appears inside a header or sequence line (for example ``c.123A>G``) does
    not start a new record.

    Returns ``[(record_id, [positions...]), ...]`` in file order.
    """
    with open(fasta_path, "rb") as f:
        # mmap refuses to map a zero-length file, so handle it up front.
        if fstat(f.fileno()).st_size == 0:
            return []

        with mmap(f.fileno(), 0, access=ACCESS_READ) as text:
            # Slice the mapping into per-record byte strings. Each slice is a
            # copy, so the records stay valid after the mapping is closed.
            records: list[bytes] = []

            if text[:1] == b">":
                start = 0
            else:
                # Skip anything that precedes the first header line.
                newline = text.find(b"\n>")
                start = newline + 1 if newline != -1 else -1

            while start != -1:
                # Only a ">" directly after a newline opens the next record.
                newline = text.find(b"\n>", start + 1)
                if newline == -1:
                    record = text[start + 1 :]
                    start = -1
                else:
                    record = text[start + 1 : newline + 1]
                    start = newline + 1

                if record.strip():
                    records.append(record)

    with ThreadPoolExecutor(max_workers=_NUM_WORKERS) as executor:
        # executor.map yields results in submission order, i.e. file order.
        results = list(
            executor.map(lambda record: _scan_record(record, pattern), records)
        )

    return [result for result in results if result is not None]

0 commit comments

Comments
 (0)