Skip to content

Commit 9cd65cc

Browse files
Merge branch 'develop' into feature/karen_bacterial_wrapper
2 parents 0f0bf1b + 670bc10 commit 9cd65cc

28 files changed

Lines changed: 4758 additions & 151 deletions

.gitignore

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,31 @@
11
# Ignore filetypes
22
*.pyc
3+
*.pyo
4+
*.pyd
5+
__pycache__/
6+
7+
# Virtual environments
38
/python2env/
9+
/.venv/
10+
/venv/
11+
/env/
12+
13+
# IDEs
414
/.ipynb_checkpoints/
515
/.vscode/
6-
/.idea/
16+
/.idea/
17+
18+
# Test & coverage artifacts
19+
.coverage
20+
.coverage.*
21+
htmlcov/
22+
.pytest_cache/
23+
24+
# NEAT log files
25+
*.log
26+
27+
# Build / packaging
28+
dist/
29+
build/
30+
*.egg-info/
31+
*.egg

neat/read_simulator/utils/bed_func.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,30 +21,29 @@
2121

2222
def intersect_regions(mutation_regions: list, block_tuple: tuple[int, int], default_value: float) -> list:
2323
"""
24+
Clips each mutation region to the block window and returns only the overlapping
25+
sub-intervals, preserving each region's mutation rate.
26+
2427
Our assumption here is that mutation regions is a continuous list, such that
2528
for each region, the end of the previous region is the start of the next region,
2629
and there are no gaps. This should be true of anything generated from parse_beds, but
2730
needs some tests to verify
2831
"""
2932
ret_list = []
3033
block_start, block_end = block_tuple
31-
for i in range(len(mutation_regions)):
32-
region = mutation_regions[i]
33-
if region[0] <= block_start < region[1]:
34-
# We found the first region covering the block
35-
if block_end <= region[1]:
36-
# If the block spans the entire region, we have a special case
37-
ret_list.append((block_start, block_end, region[2]))
38-
# nothing more to do
39-
return ret_list
40-
ret_list.append((block_start, region[1], region[2]))
41-
elif region[0] <= block_end < region[1]:
42-
# We found the last region covering the block
43-
ret_list.append((region[0], block_end, region[2]))
44-
# nothing more to do
45-
return ret_list
46-
# If we haven't returned yet, then we did not find the end in our mutations list
47-
ret_list.append((mutation_regions[-1][1], block_end, default_value))
34+
for region in mutation_regions:
35+
overlap_start = max(region[0], block_start)
36+
overlap_end = min(region[1], block_end)
37+
if overlap_start < overlap_end:
38+
ret_list.append((overlap_start, overlap_end, region[2]))
39+
40+
if not ret_list:
41+
# Block is entirely outside all provided regions
42+
ret_list.append((block_start, block_end, default_value))
43+
elif ret_list[-1][1] < block_end:
44+
# Block extends past the last region; fill the tail with the default rate
45+
ret_list.append((ret_list[-1][1], block_end, default_value))
46+
4847
return ret_list
4948

5049
def parse_beds(options: Options, ref_keys_counts: dict) -> list:

neat/read_simulator/utils/generate_reads.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,13 @@ def generate_reads(
183183
# _LOG.info(f'Sampling reads for thread {thread_index}...')
184184
start_time = time.time()
185185

186+
if len(reference) < options.read_len:
187+
_LOG.warning(
188+
f"Contig '{contig_name}' (length {len(reference)}) is shorter than read_len "
189+
f"({options.read_len}). Skipping contig."
190+
)
191+
return []
192+
186193
# _LOG.debug("Covering dataset.")
187194
t = time.time()
188195
reads = cover_dataset(

neat/read_simulator/utils/generate_variants.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ def generate_variants(
7777
for variant in input_variants.contig_variants[variant_location]:
7878
return_variants.add_variant(variant)
7979

80-
# pase out the mutation rates
81-
mutation_rates = np.array([x[2] for x in mutation_rate_regions])
80+
# pase out the mutation rates; substitute None with the model average
81+
mutation_rates = np.array([x[2] if x[2] is not None else mutation_model.avg_mut_rate
82+
for x in mutation_rate_regions])
8283

8384
# Trying to use a random window to keep memory under control. May need to adjust this number.
8485
max_window_size = 1000
@@ -114,13 +115,12 @@ def generate_variants(
114115
# _LOG.info(f'Planning to add {how_many_mutations} mutations. The final number may be less.')
115116

116117
while how_many_mutations > 0:
117-
# Pick a region based on the mutation rates
118-
# (default is one rate for the whole chromosome, so this will be trivial in that case
119-
# for this selection, we'll normalize the mutation rates
120-
probability_rates = mutation_rates / sum(mutation_rates)
121118
# We need to intersect our chosen mutation region with our block
122119
local_mut_regions = bed_func.intersect_regions(mutation_rate_regions, (ref_start, ref_start + len(reference)), options.mutation_rate)
123120
# For no input mutation regions bed, this will return the entire sequence.
121+
# Build probability weights from the intersected regions so the lengths always match.
122+
local_rates = np.array([r[2] if r[2] is not None else mutation_model.avg_mut_rate for r in local_mut_regions])
123+
probability_rates = local_rates / sum(local_rates)
124124
mut_region = options.rng.choice(a=local_mut_regions, p=probability_rates)
125125
mut_region_offset = (int(mut_region[0]-ref_start), int(mut_region[1]-ref_start), mut_region[2])
126126

neat/read_simulator/utils/vcf_func.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -161,19 +161,23 @@ def parse_input_vcf(
161161
# Retrieve the GT from the first sample in the record
162162
genotype = retrieve_genotype(record)
163163

164-
elif "WP" in [x.split('=') for x in record[7].split(';')]:
164+
elif "WP" in [x.split('=')[0] for x in record[7].split(';') if '=' in x]:
165165
"""
166166
"WP" is the legacy code NEAT used for genotype it added. It was found in the INFO field.
167167
We're just going to make a sample column in this version of NEAT
168168
The logic of the statement is split the info field on ';' which is used as a divider in that field.
169169
Most but not all fields also have an '=', so split there too, then look for "WP"
170170
"""
171171
format_column = f"GT:{record[8]}"
172-
for record in record[7].split(';'):
173-
if record.startswith('WP'):
174-
genotype = record.split('=')[1].replace('/', '|').split('|')
172+
sample_field = record[9]
173+
for info_item in record[7].split(';'):
174+
if info_item.startswith('WP') and '=' in info_item:
175+
genotype = info_item.split('=')[1].replace('/', '|').split('|')
175176
genotype = np.array([int(x) for x in genotype])
176-
normal_sample_field = f"{get_genotype_string(genotype)}:{record[9]}"
177+
normal_sample_field = f"{get_genotype_string(genotype)}:{sample_field}"
178+
elif info_item.startswith('WP'):
179+
_LOG.error(f'Malformed WP field in INFO (missing value): {record[7]}')
180+
sys.exit(1)
177181

178182
else:
179183
format_column = 'GT:' + record[8]
@@ -182,20 +186,22 @@ def parse_input_vcf(
182186
gt_field = get_genotype_string(genotype)
183187
normal_sample_field = f'{gt_field}:{record[9]}'
184188

185-
elif "WP" in [x.split('=') for x in record[7].split(';')]:
189+
elif "WP" in [x.split('=')[0] for x in record[7].split(';') if '=' in x]:
186190
"""
187191
"WP" is the legacy code NEAT used for genotype it added. It was found in the INFO field.
188192
We're just going to make a sample column in this version of NEAT
189193
The logic of the statement is split the info field on ';' which is used as a divider in that field.
190194
Most but not all fields also have an '=', so split there too, then look for "WP"
191195
"""
192196
format_column = "GT"
193-
info_split = record[7].split(';')
194-
for record in info_split:
195-
if record.startswith('WP'):
196-
genotype = record.split('=')[1].replace('/', '|').split('|')
197+
for info_item in record[7].split(';'):
198+
if info_item.startswith('WP') and '=' in info_item:
199+
genotype = info_item.split('=')[1].replace('/', '|').split('|')
197200
genotype = np.array([int(x) for x in genotype])
198201
normal_sample_field = get_genotype_string(genotype)
202+
elif info_item.startswith('WP'):
203+
_LOG.error(f'Malformed WP field in INFO (missing value): {record[7]}')
204+
sys.exit(1)
199205

200206
else:
201207
# If there was no format column, there's no sample column, so we'll generate one

neat/variants/contig_variants.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def check_if_del(self, other):
4949

5050
def check_if_ins(self, other):
5151
for insert in self.all_ins:
52-
if np.array_equal(other.genotype, insert.genotype) and insert.contains(other):
52+
if np.array_equal(other.genotype, insert.genotype) and insert.contains(other.position1):
5353
return insert
5454
return None
5555

@@ -150,11 +150,11 @@ def get_sample_info(variant):
150150
return get_genotype_string(variant.genotype)
151151

152152
def remove_variant(self, variant):
153-
if variant.position in self.variant_locations:
154-
if variant in self.contig_variants[variant.position]:
155-
self.contig_variants[variant.position].remove(variant)
156-
if not self.contig_variants[variant.position]:
157-
self.variant_locations.remove(variant.position)
153+
if variant.position1 in self.variant_locations:
154+
if variant in self.contig_variants[variant.position1]:
155+
self.contig_variants[variant.position1].remove(variant)
156+
if not self.contig_variants[variant.position1]:
157+
self.variant_locations.remove(variant.position1)
158158

159159
def __getitem__(self, input_location: int) -> list:
160160
"""

tests/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import logging
2+
import pytest
3+
4+
5+
@pytest.fixture(autouse=True)
6+
def _isolate_neat_logging():
7+
"""
8+
Close and remove any FileHandlers attached to NEAT loggers before each test.
9+
Prevents 'ValueError: I/O operation on closed file' errors when a FileHandler
10+
from a previous test is still attached after its underlying file is closed.
11+
Propagation is left intact so caplog can capture NEAT log output.
12+
"""
13+
def _close_file_handlers(logger):
14+
for h in list(logger.handlers):
15+
if isinstance(h, logging.FileHandler):
16+
logger.removeHandler(h)
17+
try:
18+
h.close()
19+
except Exception:
20+
pass
21+
22+
for name, logger in list(logging.Logger.manager.loggerDict.items()):
23+
if (name == "neat" or name.startswith("neat.")) and isinstance(logger, logging.Logger):
24+
_close_file_handlers(logger)
25+
26+
yield
27+
28+
for name, logger in list(logging.Logger.manager.loggerDict.items()):
29+
if (name == "neat" or name.startswith("neat.")) and isinstance(logger, logging.Logger):
30+
_close_file_handlers(logger)

tests/test_cli/test_basic_cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_basic_cli():
3636
stdout=subprocess.PIPE,
3737
stderr=subprocess.PIPE,
3838
text=True,
39+
cwd=str(td),
3940
)
4041
assert proc.returncode == 0, f"STDERR:\n{proc.stderr}"
4142
assert out.exists()

tests/test_models/test_error_and_mut_models.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,8 @@ def test_mutation_model_generate_snv_trinuc():
4141
assert snv.alt in ["A", "C", "G", "T"]
4242

4343

44-
def test_sequencing_error_model_zero_error_returns_none_or_empty():
45-
"""
46-
avg_seq_error == 0 should yield no errors.
47-
"""
48-
rng = default_rng(4)
49-
sem = SequencingErrorModel(avg_seq_error=0.0)
50-
ref = SeqRecord(Seq("A" * 40), id="chr1")
51-
quals = np.array([40] * 40, dtype=int)
52-
result = sem.get_sequencing_errors(
53-
padding=20,
54-
reference_segment=ref,
55-
quality_scores=quals,
56-
rng=rng,
57-
)
58-
if isinstance(result, tuple):
59-
introduced, pad = result
60-
assert introduced == []
61-
assert pad >= 0
62-
else:
63-
assert result == []
44+
# test_sequencing_error_model_zero_error_returns_none_or_empty removed:
45+
# duplicate of test_error_models.py::test_sem_zero_error_rate_returns_empty
6446

6547

6648
def test_traditional_quality_model_shapes_and_range():
@@ -135,16 +117,8 @@ def test_mutation_model_snv_does_not_keep_reference_base():
135117
assert snv.alt != central
136118

137119

138-
def test_traditional_quality_model_reproducible_with_seed():
139-
"""Quality model should be deterministic given the same RNG state."""
140-
rng1 = default_rng(8)
141-
rng2 = default_rng(8)
142-
qm = TraditionalQualityModel(average_error=0.01)
143-
144-
qs1 = qm.get_quality_scores(model_read_length=151, length=100, rng=rng1)
145-
qs2 = qm.get_quality_scores(model_read_length=151, length=100, rng=rng2)
146-
147-
assert np.array_equal(qs1, qs2)
120+
# test_traditional_quality_model_reproducible_with_seed removed:
121+
# duplicate of test_error_models.py::test_tqm_get_quality_scores_reproducible
148122

149123

150124
def test_sequencing_error_model_reproducible_with_seed():

0 commit comments

Comments
 (0)