Skip to content

Commit 702c828

Browse files
committed
threads
1 parent 734bc9b commit 702c828

1 file changed

Lines changed: 56 additions & 28 deletions

File tree

rounds/3_dna/solution.py

Lines changed: 56 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -2,18 +2,17 @@
22

33
from __future__ import annotations
44

5+
import os
6+
from concurrent.futures import ThreadPoolExecutor
7+
58
import numpy as np
69

710
_NEWLINE = b"\n"
11+
_MAX_WORKERS = 12
812

913

1014
def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]:
11-
"""Find every FASTA record whose sequence contains ``pattern``.
12-
13-
This version assumes the benchmark-sized generated FASTA input: ASCII
14-
headers, DNA sequence lines separated by ``\n``, and no whitespace inside
15-
sequence lines besides those newlines.
16-
"""
15+
"""Find every FASTA record whose sequence contains ``pattern``."""
1716
if not pattern:
1817
return []
1918

@@ -23,28 +22,57 @@ def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]
2322
with open(fasta_path, "rb") as file:
2423
data = file.read()
2524

26-
matches: list[tuple[str, list[int]]] = []
27-
for record in data.split(b">")[1:]:
28-
record_id, _, wrapped_sequence = record.partition(_NEWLINE)
29-
sequence = wrapped_sequence.replace(_NEWLINE, b"")
30-
sequence_len = len(sequence)
31-
if sequence_len < pattern_len:
32-
continue
33-
34-
sequence_values = np.frombuffer(sequence, dtype=np.uint8)
35-
positions_mask = (
36-
sequence_values[: sequence_len - pattern_len + 1] == pattern_values[0]
25+
records = data.split(b">")[1:]
26+
worker_count = min(_MAX_WORKERS, os.cpu_count() or 1, len(records))
27+
if worker_count <= 1:
28+
return _scan_records(records, pattern_values, pattern_len)
29+
30+
chunk_size = (len(records) + worker_count - 1) // worker_count
31+
chunks = [
32+
records[start : start + chunk_size]
33+
for start in range(0, len(records), chunk_size)
34+
]
35+
with ThreadPoolExecutor(max_workers=worker_count) as executor:
36+
groups = executor.map(
37+
_scan_records,
38+
chunks,
39+
[pattern_values] * len(chunks),
40+
[pattern_len] * len(chunks),
3741
)
38-
for pattern_index in range(1, pattern_len):
39-
positions_mask &= (
40-
sequence_values[
41-
pattern_index : sequence_len - pattern_len + 1 + pattern_index
42-
]
43-
== pattern_values[pattern_index]
44-
)
45-
46-
positions = np.nonzero(positions_mask)[0]
47-
if positions.size:
48-
matches.append((record_id.decode("ascii"), positions.tolist()))
4942

43+
return [match for group in groups for match in group]
44+
45+
46+
def _scan_records(
    records: list[bytes], pattern_values: np.ndarray, pattern_len: int
) -> list[tuple[str, list[int]]]:
    """Scan a batch of FASTA records and collect the ones that match.

    Each record is passed to ``_scan_record`` together with the pattern
    arguments; records for which it returns ``None`` are dropped. The
    surviving results keep the order of ``records``.

    Args:
        records: Raw record payloads to scan, one ``bytes`` object each.
        pattern_values: Pattern array forwarded unchanged to ``_scan_record``.
        pattern_len: Pattern length forwarded unchanged to ``_scan_record``.

    Returns:
        The non-``None`` results of ``_scan_record``, in input order.
    """
    # Lazily evaluate one record at a time; only hits are retained.
    outcomes = (
        _scan_record(raw_record, pattern_values, pattern_len)
        for raw_record in records
    )
    return [hit for hit in outcomes if hit is not None]
55+
56+
57+
def _scan_record(
    record: bytes, pattern_values: np.ndarray, pattern_len: int
) -> tuple[str, list[int]] | None:
    """Search a single FASTA record for every occurrence of the pattern.

    ``record`` is split at its first line feed into a header line and the
    remaining (possibly line-wrapped) sequence text. Line breaks are removed
    from the sequence before searching, so reported offsets index the
    unwrapped sequence and a match may span a wrapped line boundary. Both LF
    and CRLF line endings are accepted.

    Args:
        record: Header line followed by the sequence lines of one record.
        pattern_values: The pattern as an array; element ``i`` is compared
            against the sequence byte at offset ``start + i``.
        pattern_len: Number of pattern elements to compare (callers are
            expected to pass a value of at least 1).

    Returns:
        ``(header, positions)`` where ``positions`` holds the zero-based
        start offset of every occurrence (overlapping ones included), or
        ``None`` when the sequence is shorter than the pattern or contains
        no occurrence.

    Raises:
        UnicodeDecodeError: If the header of a matching record is not ASCII.
    """
    header, _, wrapped_sequence = record.partition(b"\n")
    # Delete CR together with LF in one pass: with CRLF input a leftover
    # "\r" would sit inside the sequence, shifting every later offset and
    # hiding matches that straddle a line break.
    sequence = wrapped_sequence.translate(None, b"\r\n")
    sequence_len = len(sequence)
    if sequence_len < pattern_len:
        return None

    sequence_values = np.frombuffer(sequence, dtype=np.uint8)
    # Number of start offsets at which the whole pattern still fits.
    candidate_count = sequence_len - pattern_len + 1
    # Start with offsets matching the first pattern byte, then narrow the
    # mask by comparing each further pattern byte against the sequence
    # shifted by that byte's index.
    positions_mask = sequence_values[:candidate_count] == pattern_values[0]
    for pattern_index in range(1, pattern_len):
        positions_mask &= (
            sequence_values[pattern_index : candidate_count + pattern_index]
            == pattern_values[pattern_index]
        )

    positions = np.nonzero(positions_mask)[0]
    if positions.size:
        # A CRLF header line keeps its "\r" after partitioning on LF.
        return header.rstrip(b"\r").decode("ascii"), positions.tolist()
    return None

0 commit comments

Comments
 (0)