Skip to content

Commit 5b07a1b

Browse files
Rusty Johnson
authored and committed
rsjohnson3: Speedup
1 parent 9465158 commit 5b07a1b

1 file changed

Lines changed: 134 additions & 75 deletions

File tree

rounds/3_dna/solution.py

Lines changed: 134 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import mmap
34
import os
45
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
56
from os import PathLike
@@ -8,63 +9,75 @@
89

910
Pathish = Union[str, bytes, PathLike[str], PathLike[bytes]]
1011

11-
Record = tuple[int, str, bytearray]
12+
# (record_index, record_start_offset, record_end_offset)
13+
Span = tuple[int, int, int]
14+
15+
# (record_index, record_id, match_positions)
1216
SearchResult = tuple[int, str, list[int]]
1317

18+
# Baseline behavior removes spaces and newlines from sequence text.
19+
# In binary mode we also remove '\r' to match text-mode universal newlines.
20+
_DELETE_SEQUENCE_BYTES = b" \r\n"
1421

15-
def _iter_fasta_records(fasta_path: Pathish) -> Iterator[Record]:
16-
"""
17-
Yield FASTA records as:
1822

19-
(record_index, record_id, sequence)
23+
def _default_worker_count() -> int:
24+
# Python 3.13+ may expose process_cpu_count(), which respects CPU limits.
25+
process_cpu_count = getattr(os, "process_cpu_count", None)
26+
27+
if process_cpu_count is not None:
28+
count = process_cpu_count()
29+
else:
30+
count = os.cpu_count()
31+
32+
return count or 1
33+
2034

21-
The sequence is accumulated as bytes, with literal spaces removed to match
22-
the baseline behavior.
35+
def _iter_record_spans(mm: mmap.mmap, size: int) -> Iterator[Span]:
2336
"""
37+
Yield FASTA record byte ranges.
2438
25-
record_id: str | None = None
26-
sequence = bytearray()
27-
index = 0
39+
Assumes valid FASTA-style records where headers begin with '>' at the start
40+
of a line. This is faster than splitting the whole file on b'>'.
41+
"""
2842

29-
with open(fasta_path, "rb") as f:
30-
for raw_line in f:
31-
if raw_line[:1] == b">":
32-
if record_id is not None:
33-
yield index, record_id, sequence
34-
index += 1
43+
if size == 0:
44+
return
3545

36-
record_id = raw_line[1:].strip().decode("ascii")
37-
sequence = bytearray()
38-
continue
46+
if mm[:1] == b">":
47+
start = 0
48+
else:
49+
marker = mm.find(b"\n>")
50+
if marker < 0:
51+
return
52+
start = marker + 1
3953

40-
# Ignore preamble before the first FASTA header.
41-
if record_id is None:
42-
continue
54+
index = 0
4355

44-
line = raw_line.rstrip(b"\r\n")
56+
while start < size:
57+
next_marker = mm.find(b"\n>", start + 1)
58+
end = size if next_marker < 0 else next_marker
4559

46-
# Match the baseline's `.replace(" ", "")`.
47-
if b" " in line:
48-
line = line.replace(b" ", b"")
60+
yield index, start, end
4961

50-
sequence.extend(line)
62+
index += 1
5163

52-
if record_id is not None:
53-
yield index, record_id, sequence
64+
if next_marker < 0:
65+
break
5466

67+
start = next_marker + 1
5568

56-
def _find_overlapping_positions(sequence: bytearray, pattern: bytes) -> list[int]:
69+
70+
def _find_overlapping_positions(sequence: bytes, pattern: bytes) -> list[int]:
5771
"""
58-
Find all overlapping occurrences of pattern in sequence.
72+
Return every overlapping occurrence of pattern in sequence.
5973
6074
Example:
6175
sequence = b"AAAA"
6276
pattern = b"AA"
6377
result = [0, 1, 2]
6478
"""
6579

66-
# Preserve baseline behavior:
67-
# an empty pattern matches every position from 0 through len(sequence).
80+
# Match the baseline's empty-pattern behavior.
6881
if not pattern:
6982
return list(range(len(sequence) + 1))
7083

@@ -76,74 +89,107 @@ def _find_overlapping_positions(sequence: bytearray, pattern: bytes) -> list[int
7689

7790
while True:
7891
pos = find(pattern, start)
92+
7993
if pos < 0:
8094
return positions
8195

8296
append(pos)
8397
start = pos + 1
8498

8599

86-
def _search_batch(batch: list[Record], pattern: bytes) -> list[SearchResult]:
100+
def _search_batch(
101+
mm: mmap.mmap,
102+
spans: list[Span],
103+
pattern: bytes,
104+
) -> list[SearchResult]:
87105
"""
88106
Worker function.
89107
90-
Each worker receives a batch of records to reduce ThreadPoolExecutor
91-
scheduling overhead for FASTA files with many small records.
108+
Each worker processes a batch of records. Batching is important for a file
109+
with ~10k sequences because submitting 10k individual futures is wasteful.
92110
"""
93111

94-
return [
95-
(index, record_id, _find_overlapping_positions(sequence, pattern))
96-
for index, record_id, sequence in batch
97-
]
112+
results: list[SearchResult] = []
113+
append_result = results.append
114+
delete_bytes = _DELETE_SEQUENCE_BYTES
115+
116+
for index, start, end in spans:
117+
header_end = mm.find(b"\n", start, end)
118+
119+
if header_end < 0:
120+
# Header-only record.
121+
record_id = mm[start + 1 : end].strip().decode("ascii")
122+
sequence = b""
123+
else:
124+
record_id = mm[start + 1 : header_end].strip().decode("ascii")
125+
126+
# This does sequence normalization in C:
127+
# remove line breaks and spaces from the sequence portion.
128+
sequence = mm[header_end + 1 : end].translate(None, delete_bytes)
129+
130+
positions = _find_overlapping_positions(sequence, pattern)
131+
append_result((index, record_id, positions))
132+
133+
return results
98134

99135

100136
def find_matches(
101137
fasta_path: Pathish,
102138
pattern: bytes,
103139
*,
104140
max_workers: int | None = None,
141+
batch_records: int = 128,
142+
batch_bytes: int = 8 << 20, # 8 MiB
105143
max_pending_batches: int | None = None,
106-
batch_records: int = 64,
107-
batch_bytes: int = 8 << 20, # 8 MiB of sequence data
108144
) -> list[tuple[str, list[int]]]:
109145
"""
110146
Find every FASTA record whose sequence contains `pattern`.
111147
112148
Returns:
113149
[(record_id, [positions...]), ...]
114150
115-
Threaded design:
116-
- main thread parses the FASTA file
117-
- worker threads search records in parallel
118-
- main thread collects results and emits them in original file order
151+
Tuned for roughly:
152+
- 512 MB input
153+
- ~10,145 records
154+
- free-threaded CPython
119155
120-
This is designed for free-threaded Python. On normal GIL-enabled CPython,
121-
CPU-bound speedup may be much smaller.
156+
The defaults create approximately 60-90 tasks for your file size, rather
157+
than 10,145 tiny tasks.
122158
"""
123159

124160
pattern = bytes(pattern)
125161

162+
# Preserve the baseline's assumption that the pattern is ASCII text.
163+
pattern.decode("ascii")
164+
126165
if max_workers is None:
127-
max_workers = os.cpu_count() or 1
166+
max_workers = _default_worker_count()
167+
128168
if max_workers < 1:
129169
raise ValueError("max_workers must be positive")
130170

131-
if max_pending_batches is None:
132-
max_pending_batches = max_workers * 2
133-
if max_pending_batches < 1:
134-
raise ValueError("max_pending_batches must be positive")
135-
136171
if batch_records < 1:
137172
raise ValueError("batch_records must be positive")
173+
138174
if batch_bytes < 1:
139175
raise ValueError("batch_bytes must be positive")
140176

177+
if max_pending_batches is None:
178+
max_pending_batches = max_workers * 4
179+
180+
if max_pending_batches < 1:
181+
raise ValueError("max_pending_batches must be positive")
182+
183+
size = os.path.getsize(fasta_path)
184+
185+
if size == 0:
186+
return []
187+
141188
matches: list[tuple[str, list[int]]] = []
142189

143190
# Completed records waiting to be emitted in file order.
144191
ready: dict[int, tuple[str, list[int]]] = {}
145192

146-
pending: set[Future[list[SearchResult]]] = set()
147193
next_to_emit = 0
148194

149195
def collect(done: set[Future[list[SearchResult]]]) -> None:
@@ -153,7 +199,7 @@ def collect(done: set[Future[list[SearchResult]]]) -> None:
153199
for index, record_id, positions in future.result():
154200
ready[index] = (record_id, positions)
155201

156-
# Emit only when the next file-order record is available.
202+
# Preserve file order even when worker batches complete out of order.
157203
while next_to_emit in ready:
158204
record_id, positions = ready.pop(next_to_emit)
159205

@@ -162,29 +208,42 @@ def collect(done: set[Future[list[SearchResult]]]) -> None:
162208

163209
next_to_emit += 1
164210

165-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
166-
batch: list[Record] = []
167-
batch_size = 0
168-
169-
for record in _iter_fasta_records(fasta_path):
170-
batch.append(record)
171-
batch_size += len(record[2])
211+
with open(fasta_path, "rb") as file:
212+
with mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ) as mm:
213+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
214+
pending: set[Future[list[SearchResult]]] = set()
172215

173-
if len(batch) >= batch_records or batch_size >= batch_bytes:
174-
pending.add(executor.submit(_search_batch, batch, pattern))
175-
batch = []
216+
batch: list[Span] = []
176217
batch_size = 0
177218

178-
# Backpressure: do not let the parser enqueue the whole file.
179-
if len(pending) >= max_pending_batches:
180-
done, pending = wait(pending, return_when=FIRST_COMPLETED)
181-
collect(done)
219+
for span in _iter_record_spans(mm, size):
220+
_, start, end = span
221+
222+
batch.append(span)
223+
batch_size += end - start
182224

183-
if batch:
184-
pending.add(executor.submit(_search_batch, batch, pattern))
225+
if len(batch) >= batch_records or batch_size >= batch_bytes:
226+
pending.add(executor.submit(_search_batch, mm, batch, pattern))
185227

186-
while pending:
187-
done, pending = wait(pending, return_when=FIRST_COMPLETED)
188-
collect(done)
228+
batch = []
229+
batch_size = 0
230+
231+
# Backpressure. Avoid queueing unbounded work.
232+
if len(pending) >= max_pending_batches:
233+
done, pending = wait(
234+
pending,
235+
return_when=FIRST_COMPLETED,
236+
)
237+
collect(done)
238+
239+
if batch:
240+
pending.add(executor.submit(_search_batch, mm, batch, pattern))
241+
242+
while pending:
243+
done, pending = wait(
244+
pending,
245+
return_when=FIRST_COMPLETED,
246+
)
247+
collect(done)
189248

190249
return matches

0 commit comments

Comments
 (0)