Skip to content

Commit b0dd57f

Browse files
Fix performance regressions in histogram and DNA solutions
1 parent 1094bc5 commit b0dd57f

2 files changed

Lines changed: 73 additions & 62 deletions

File tree

rounds/1_histogram/solution.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,23 @@
44
passes out of the box. Replace the body of ``compute_histogram`` with your
55
own faster implementation.
66
"""
7-
from collections import defaultdict
8-
from mmap import mmap, ACCESS_READ
97

10-
def b2i(low: int, high: int) -> int:
11-
return high + (low << 8)
8+
import numpy as np
129

13-
def i2b(x: int) -> bytes:
14-
return bytes([(x & 0xFF00) >> 8, x & 0xFF])
1510

1611
def compute_histogram(
    path: str, *, chunk_size: int = 16 * 1024 * 1024
) -> dict[bytes, int]:
    """Return the frequency of every 2-byte bigram in the file at ``path``.

    A bigram is each pair of adjacent bytes: ``b"ABCD"`` yields ``b"AB"``,
    ``b"BC"`` and ``b"CD"``.  The file is streamed in ``chunk_size``-byte
    pieces so the temporary index arrays are bounded by the chunk size rather
    than by the file size.

    Args:
        path: File to scan; opened in binary mode.
        chunk_size: Number of bytes read per iteration (must be >= 1).

    Returns:
        Mapping from each bigram that occurs at least once to its count,
        with keys in ascending byte order.  Files shorter than two bytes
        produce an empty dict.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")

    counts = np.zeros(65536, dtype=np.int64)
    last_byte = None  # final byte of the previous chunk, if any

    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break

            # The bigram straddling two chunks is counted by hand.
            if last_byte is not None:
                counts[(last_byte << 8) | chunk[0]] += 1
            last_byte = chunk[-1]

            arr = np.frombuffer(chunk, dtype=np.uint8)
            if arr.size > 1:
                # first_byte * 256 + second_byte; the maximum is 65535, so
                # uint16 cannot overflow.
                indices = (arr[:-1].astype(np.uint16) << 8) | arr[1:]
                counts += np.bincount(indices, minlength=65536)

    # Build the result dict from non-zero entries only.
    present = np.flatnonzero(counts)
    return {
        idx.to_bytes(2, "big"): total
        for idx, total in zip(present.tolist(), counts[present].tolist())
    }

rounds/3_dna/solution.py

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,58 +5,78 @@
55
own faster implementation.
66
"""
77

8-
from mmap import mmap, ACCESS_READ
9-
from concurrent.futures import ThreadPoolExecutor, wait
8+
from __future__ import annotations
109

11-
def _subsearch(raw, record_id_start: int, data_start: int, data_end: int, pattern: bytes):
12-
plen = len(pattern)
13-
data = bytes(raw[data_start : data_end - 1]).replace(b"\n", b"")
14-
locations = []
15-
loc = data.find(pattern)
16-
while loc != -1:
17-
locations.append(loc)
18-
loc = data.find(pattern, loc + plen)
10+
import os
11+
from concurrent.futures import ThreadPoolExecutor
1912

20-
if not locations:
21-
return None
13+
_NL = 0x0A
2214

23-
record_id = raw[record_id_start : data_start - 1].decode("ascii")
24-
return (record_id, locations)
2515

2616
def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
    """Find every FASTA record whose sequence contains ``pattern``.

    A record starts at a ``>`` that is the first byte of the file or directly
    follows a newline.  The rest of that line is the record id; all following
    lines up to the next record form the sequence.  Line terminators (``\\n``
    and ``\\r``) are removed before searching, so matches spanning a line
    break are found and offsets are relative to the joined sequence for both
    LF and CRLF files.  Overlapping occurrences are all reported.

    Args:
        fasta_path: FASTA file to scan; read fully into memory.
        pattern: Byte string to look for.  An empty pattern matches at every
            offset, as ``bytes.find`` does.

    Returns:
        ``[(record_id, [positions...]), ...]`` in file order, containing only
        records with at least one match.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the id of a matching record is not ASCII.
    """
    with open(fasta_path, "rb") as f:
        data = f.read()

    # Step 1: locate every record start ('>' at offset 0 or right after '\n').
    starts: list[int] = [0] if data.startswith(b">") else []
    marker = data.find(b"\n>")
    while marker != -1:
        starts.append(marker + 1)
        marker = data.find(b"\n>", marker + 1)
    if not starts:
        return []
    starts.append(len(data))  # sentinel: end of the last record
    num_records = len(starts) - 1

    # Step 2: scan records in batches so each task amortises pool overhead.
    n_workers = max(1, os.cpu_count() or 1)
    batch_size = max(1, -(-num_records // (n_workers * 4)))  # ceil division

    def scan_batch(first: int) -> list[tuple[str, list[int]]]:
        """Scan records ``first`` .. ``first + batch_size - 1`` in order."""
        found: list[tuple[str, list[int]]] = []
        for j in range(first, min(first + batch_size, num_records)):
            rec_start, rec_end = starts[j], starts[j + 1]

            header_end = data.find(b"\n", rec_start, rec_end)
            if header_end == -1:
                continue  # header line only, no sequence to search

            # Drop LF and CR in one pass so cross-line matches are found.
            sequence = data[header_end + 1 : rec_end].translate(None, b"\r\n")

            positions: list[int] = []
            hit = sequence.find(pattern)
            while hit != -1:
                positions.append(hit)
                hit = sequence.find(pattern, hit + 1)  # +1 keeps overlaps

            if positions:
                record_id = data[rec_start + 1 : header_end].decode("ascii").strip()
                found.append((record_id, positions))
        return found

    # Step 3: Executor.map yields results in submission order, which is file
    # order, so no re-sorting is needed.
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        batches = pool.map(scan_batch, range(0, num_records, batch_size))
        return [match for batch in batches for match in batch]

0 commit comments

Comments
 (0)