Skip to content

Commit 280ea70

Browse files
committed
part3: threadpool
1 parent cafc8ef commit 280ea70

1 file changed

Lines changed: 89 additions & 6 deletions

File tree

rounds/3_dna/solution.py

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,96 @@
55
own faster implementation.
66
"""
77

8-
from .baseline import find_matches as _baseline
8+
from concurrent.futures import ThreadPoolExecutor
9+
from threading import Thread
10+
11+
12+
def find_matches_in_sequence(
13+
record_id: str,
14+
sequence: str,
15+
pattern_str: str,
16+
matches: list[tuple[str, list[int]]],
17+
):
18+
"""Find matches in a single sequence and append to the shared matches list."""
19+
positions: list[int] = []
20+
start = 0
21+
while True:
22+
pos = sequence.find(pattern_str, start)
23+
if pos == -1:
24+
break
25+
positions.append(pos)
26+
start = pos + 1
27+
28+
if positions:
29+
matches.append((record_id, positions))
30+
31+
32+
def find_matches_many_threads(
33+
fasta_path: str, pattern: bytes
34+
) -> list[tuple[str, list[int]]]:
35+
# Step 1: read the whole FASTA file as text and decode the pattern so the
36+
# search below can use a single ``str`` API.
37+
pattern_str = pattern.decode("ascii")
38+
with open(fasta_path, "r") as f:
39+
text = f.read()
40+
41+
matches: list[tuple[str, list[int]]] = []
42+
43+
# Preprocess the sequences
44+
sequences = []
45+
for record in text.split(">"):
46+
if not record.strip():
47+
continue
48+
lines = record.split("\n")
49+
record_id = lines[0].strip()
50+
sequence = "".join(lines[1:]).replace(" ", "")
51+
sequences.append((record_id, sequence))
52+
threads = []
53+
for record_id, sequence in sequences:
54+
thread = Thread(
55+
target=find_matches_in_sequence,
56+
args=(record_id, sequence, pattern_str, matches),
57+
)
58+
thread.start()
59+
threads.append(thread)
60+
# Wait for all threads to finish
61+
print(f"Waiting for {len(threads)} threads to finish...")
62+
for thread in threads:
63+
thread.join()
64+
65+
return matches
966

1067

1168
def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
12-
"""Find every FASTA record whose sequence contains ``pattern``.
69+
# Step 1: read the whole FASTA file as text and decode the pattern so the
70+
# search below can use a single ``str`` API.
71+
pattern_str = pattern.decode("ascii")
72+
with open(fasta_path, "r") as f:
73+
text = f.read()
74+
75+
matches: list[tuple[str, list[int]]] = []
76+
77+
# Preprocess the sequences
78+
sequences = []
79+
for record in text.split(">"):
80+
if not record.strip():
81+
continue
82+
lines = record.split("\n")
83+
record_id = lines[0].strip()
84+
sequence = "".join(lines[1:]).replace(" ", "")
85+
sequences.append((record_id, sequence))
86+
87+
# Create a pool of threads
88+
pool = ThreadPoolExecutor(max_workers=len(sequences))
89+
for record_id, sequence in sequences:
90+
pool.submit(
91+
find_matches_in_sequence,
92+
record_id,
93+
sequence,
94+
pattern_str,
95+
matches,
96+
)
97+
# Wait for all threads to finish
98+
pool.shutdown(wait=True)
1399

14-
Returns ``[(record_id, [positions...]), ...]`` in file order.
15-
"""
16-
# TODO: remove this delegation and write your own implementation here.
17-
return _baseline(fasta_path, pattern)
100+
return matches

0 commit comments

Comments
 (0)