Skip to content

Commit de6ea69

Browse files
committed
part3: use bytes and batches
1 parent 0d70ff6 commit de6ea69

1 file changed

Lines changed: 73 additions & 98 deletions

File tree

rounds/3_dna/solution.py

Lines changed: 73 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -5,105 +5,80 @@
55
own faster implementation.
66
"""
77

8+
from __future__ import annotations
9+
10+
import os
811
from concurrent.futures import ThreadPoolExecutor
9-
from threading import Thread
10-
11-
12-
def find_matches_in_sequence(
    record_id: str,
    sequence: str,
    pattern_str: str,
    matches: list[tuple[str, list[int]]],
):
    """Collect every occurrence of ``pattern_str`` in one sequence.

    Overlapping occurrences are reported: after each hit the search resumes
    one character past the hit's start. When at least one occurrence exists,
    ``(record_id, offsets)`` is appended to the caller-supplied ``matches``
    list; otherwise ``matches`` is left untouched. Nothing is returned.
    """
    hits: list[int] = []
    offset = sequence.find(pattern_str)
    while offset != -1:
        hits.append(offset)
        # Advance by one (not by the pattern length) so overlaps are kept.
        offset = sequence.find(pattern_str, offset + 1)

    if not hits:
        return
    matches.append((record_id, hits))
30-
31-
32-
def find_matches_many_threads(
    fasta_path: str, pattern: bytes
) -> list[tuple[str, list[int]]]:
    """Search a FASTA file for ``pattern`` using one thread per record.

    The file is read as text, split into records on ``>``, and each record's
    sequence lines are joined with spaces removed. One ``Thread`` is started
    per record; every thread appends its result to a shared list, which is
    returned once all threads have been joined.
    """
    # Decode first so the search below works on ``str`` throughout.
    needle = pattern.decode("ascii")
    with open(fasta_path, "r") as handle:
        contents = handle.read()

    matches: list[tuple[str, list[int]]] = []

    # Parse every non-blank record into a (record id, sequence) pair.
    records = []
    for chunk in contents.split(">"):
        if chunk.strip():
            header, *body = chunk.split("\n")
            records.append((header.strip(), "".join(body).replace(" ", "")))

    # Launch one worker per record; each appends to the shared ``matches``.
    threads = []
    for name, seq in records:
        worker = Thread(
            target=find_matches_in_sequence,
            args=(name, seq, needle, matches),
        )
        worker.start()
        threads.append(worker)

    print(f"Waiting for {len(threads)} threads to finish...")
    for worker in threads:
        worker.join()

    return matches
12+
13+
_NL = 0x0A # b"\n"
6614

6715

6816
def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
    """Return every occurrence of ``pattern`` in each record of a FASTA file.

    The file is read as bytes. A record starts at a ``>`` that is either the
    first byte of the file or immediately follows a newline; anything before
    the first such ``>`` is ignored. Records are scanned in batches on a
    thread pool.

    Args:
        fasta_path: Path of the FASTA file to scan.
        pattern: Byte string to search for.

    Returns:
        ``(record_id, positions)`` pairs in file order, one per record that
        contains at least one match. ``positions`` are 0-based offsets into
        the record's sequence with line breaks and spaces removed;
        overlapping matches are included.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If a header line is not ASCII.
    """
    with open(fasta_path, "rb") as f:
        data = f.read()

    # Step 1: locate every record start. Searching for ``\n>`` directly finds
    # exactly the ``>`` bytes that begin a line.
    starts: list[int] = [0] if data.startswith(b">") else []
    hit = data.find(b"\n>")
    while hit != -1:
        starts.append(hit + 1)
        hit = data.find(b"\n>", hit + 1)
    starts.append(len(data))  # sentinel marking the end of the last record.

    num_records = len(starts) - 1
    if num_records <= 0:
        return []

    # Step 2: parallel scan. Use more batches than workers so the load stays
    # balanced even when record sizes vary.
    n_workers = max(1, os.cpu_count() or 1)
    batches = n_workers * 4
    batch_size = max(1, (num_records + batches - 1) // batches)

    def scan_batch(start_idx: int, end_idx: int) -> list[tuple[str, list[int]]]:
        """Scan records ``start_idx`` (inclusive) to ``end_idx`` (exclusive)."""
        out: list[tuple[str, list[int]]] = []
        for j in range(start_idx, end_idx):
            rec_start = starts[j]
            rec_end = starts[j + 1]

            # Locate the end of the header line within this record's slice.
            nl = data.find(b"\n", rec_start, rec_end)
            if nl == -1:
                continue  # Header-only record with no sequence lines.

            record_id = data[rec_start + 1 : nl].decode("ascii").strip()

            # Build the contiguous sequence: delete line feeds, carriage
            # returns (CRLF files) and spaces in one pass so matches that
            # straddle line breaks are found and offsets are not shifted.
            sequence = data[nl + 1 : rec_end].translate(None, b"\n\r ")

            positions: list[int] = []
            pos = sequence.find(pattern)
            while pos != -1:
                positions.append(pos)
                # Step by one so overlapping matches are reported.
                pos = sequence.find(pattern, pos + 1)

            if positions:
                out.append((record_id, positions))
        return out

    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        futures = [
            pool.submit(scan_batch, lo, min(lo + batch_size, num_records))
            for lo in range(0, num_records, batch_size)
        ]
        # Results are gathered in submission order, so file order is kept
        # no matter which batch finishes first.
        chunks = [f.result() for f in futures]

    # Step 3: flatten the per-batch results.
    return [item for chunk in chunks for item in chunk]

0 commit comments

Comments
 (0)