Skip to content

Commit 55c3728

Browse files
committed
threads 2
1 parent 702c828 commit 55c3728

1 file changed

Lines changed: 30 additions & 14 deletions

File tree

rounds/3_dna/solution.py

Lines changed: 30 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -16,16 +16,17 @@ def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]
1616
if not pattern:
1717
return []
1818

19-
pattern_values = np.frombuffer(pattern, dtype=np.uint8)
2019
pattern_len = len(pattern)
20+
pattern_prefix = np.frombuffer(pattern[:4], dtype=np.uint32)[0]
21+
pattern_suffix = np.frombuffer(pattern[4:], dtype=np.uint32)[0]
2122

2223
with open(fasta_path, "rb") as file:
2324
data = file.read()
2425

2526
records = data.split(b">")[1:]
2627
worker_count = min(_MAX_WORKERS, os.cpu_count() or 1, len(records))
2728
if worker_count <= 1:
28-
return _scan_records(records, pattern_values, pattern_len)
29+
return _scan_records(records, pattern_prefix, pattern_suffix, pattern_len)
2930

3031
chunk_size = (len(records) + worker_count - 1) // worker_count
3132
chunks = [
@@ -36,43 +37,58 @@ def find_matches(fasta_path: str, pattern: bytes) -> list[tuple[str, list[int]]]
3637
groups = executor.map(
3738
_scan_records,
3839
chunks,
39-
[pattern_values] * len(chunks),
40+
[pattern_prefix] * len(chunks),
41+
[pattern_suffix] * len(chunks),
4042
[pattern_len] * len(chunks),
4143
)
4244

4345
return [match for group in groups for match in group]
4446

4547

4648
def _scan_records(
47-
records: list[bytes], pattern_values: np.ndarray, pattern_len: int
49+
records: list[bytes],
50+
pattern_prefix: np.uint32,
51+
pattern_suffix: np.uint32,
52+
pattern_len: int,
4853
) -> list[tuple[str, list[int]]]:
4954
matches: list[tuple[str, list[int]]] = []
5055
for record in records:
51-
match = _scan_record(record, pattern_values, pattern_len)
56+
match = _scan_record(record, pattern_prefix, pattern_suffix, pattern_len)
5257
if match is not None:
5358
matches.append(match)
5459
return matches
5560

5661

5762
def _scan_record(
58-
record: bytes, pattern_values: np.ndarray, pattern_len: int
63+
record: bytes,
64+
pattern_prefix: np.uint32,
65+
pattern_suffix: np.uint32,
66+
pattern_len: int,
5967
) -> tuple[str, list[int]] | None:
6068
record_id, _, wrapped_sequence = record.partition(_NEWLINE)
6169
sequence = wrapped_sequence.replace(_NEWLINE, b"")
6270
sequence_len = len(sequence)
6371
if sequence_len < pattern_len:
6472
return None
6573

66-
sequence_values = np.frombuffer(sequence, dtype=np.uint8)
6774
candidate_count = sequence_len - pattern_len + 1
68-
positions_mask = sequence_values[:candidate_count] == pattern_values[0]
69-
for pattern_index in range(1, pattern_len):
70-
positions_mask &= (
71-
sequence_values[pattern_index : candidate_count + pattern_index]
72-
== pattern_values[pattern_index]
73-
)
75+
prefixes = np.ndarray(
76+
shape=(candidate_count,),
77+
dtype=np.uint32,
78+
buffer=sequence,
79+
strides=(1,),
80+
)
81+
candidates = np.nonzero(prefixes == pattern_prefix)[0]
82+
if not candidates.size:
83+
return None
7484

75-
positions = np.nonzero(positions_mask)[0]
85+
suffixes = np.ndarray(
86+
shape=(candidate_count,),
87+
dtype=np.uint32,
88+
buffer=memoryview(sequence)[4:],
89+
strides=(1,),
90+
)
91+
positions = candidates[suffixes[candidates] == pattern_suffix]
7692
if positions.size:
7793
return record_id.decode("ascii"), positions.tolist()
7894
return None

0 commit comments

Comments
 (0)