Skip to content

Commit 3d4a7c8

Browse files
Merge pull request #262 from ncsa/fix/generate-variants-probability-rates-mismatch
Fix probability_rates mismatch with local_mut_regions in generate_var…
2 parents 4f09fbe + e16020a commit 3d4a7c8

3 files changed

Lines changed: 231 additions & 23 deletions

File tree

neat/read_simulator/utils/bed_func.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,30 +21,29 @@
2121

2222
def intersect_regions(mutation_regions: list, block_tuple: tuple[int, int], default_value: float) -> list:
2323
"""
24+
Clips each mutation region to the block window and returns only the overlapping
25+
sub-intervals, preserving each region's mutation rate.
26+
2427
Our assumption here is that mutation regions is a continuous list, such that
2528
for each region, the end of the previous region is the start of the next region,
2629
and there are no gaps. This should be true of anything generated from parse_beds, but
2730
needs some tests to verify
2831
"""
2932
ret_list = []
3033
block_start, block_end = block_tuple
31-
for i in range(len(mutation_regions)):
32-
region = mutation_regions[i]
33-
if region[0] <= block_start < region[1]:
34-
# We found the first region covering the block
35-
if block_end <= region[1]:
36-
# If the block spans the entire region, we have a special case
37-
ret_list.append((block_start, block_end, region[2]))
38-
# nothing more to do
39-
return ret_list
40-
ret_list.append((block_start, region[1], region[2]))
41-
elif region[0] <= block_end < region[1]:
42-
# We found the last region covering the block
43-
ret_list.append((region[0], block_end, region[2]))
44-
# nothing more to do
45-
return ret_list
46-
# If we haven't returned yet, then we did not find the end in our mutations list
47-
ret_list.append((mutation_regions[-1][1], block_end, default_value))
34+
for region in mutation_regions:
35+
overlap_start = max(region[0], block_start)
36+
overlap_end = min(region[1], block_end)
37+
if overlap_start < overlap_end:
38+
ret_list.append((overlap_start, overlap_end, region[2]))
39+
40+
if not ret_list:
41+
# Block is entirely outside all provided regions
42+
ret_list.append((block_start, block_end, default_value))
43+
elif ret_list[-1][1] < block_end:
44+
# Block extends past the last region; fill the tail with the default rate
45+
ret_list.append((ret_list[-1][1], block_end, default_value))
46+
4847
return ret_list
4948

5049
def parse_beds(options: Options, ref_keys_counts: dict) -> list:

neat/read_simulator/utils/generate_variants.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ def generate_variants(
7777
for variant in input_variants.contig_variants[variant_location]:
7878
return_variants.add_variant(variant)
7979

80-
# pase out the mutation rates
81-
mutation_rates = np.array([x[2] for x in mutation_rate_regions])
80+
# pase out the mutation rates; substitute None with the model average
81+
mutation_rates = np.array([x[2] if x[2] is not None else mutation_model.avg_mut_rate
82+
for x in mutation_rate_regions])
8283

8384
# Trying to use a random window to keep memory under control. May need to adjust this number.
8485
max_window_size = 1000
@@ -114,13 +115,12 @@ def generate_variants(
114115
# _LOG.info(f'Planning to add {how_many_mutations} mutations. The final number may be less.')
115116

116117
while how_many_mutations > 0:
117-
# Pick a region based on the mutation rates
118-
# (default is one rate for the whole chromosome, so this will be trivial in that case
119-
# for this selection, we'll normalize the mutation rates
120-
probability_rates = mutation_rates / sum(mutation_rates)
121118
# We need to intersect our chosen mutation region with our block
122119
local_mut_regions = bed_func.intersect_regions(mutation_rate_regions, (ref_start, ref_start + len(reference)), options.mutation_rate)
123120
# For no input mutation regions bed, this will return the entire sequence.
121+
# Build probability weights from the intersected regions so the lengths always match.
122+
local_rates = np.array([r[2] if r[2] is not None else mutation_model.avg_mut_rate for r in local_mut_regions])
123+
probability_rates = local_rates / sum(local_rates)
124124
mut_region = options.rng.choice(a=local_mut_regions, p=probability_rates)
125125
mut_region_offset = (int(mut_region[0]-ref_start), int(mut_region[1]-ref_start), mut_region[2])
126126

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""
2+
Regression tests for the probability_rates / local_mut_regions mismatch and
3+
the intersect_regions algorithm fix.
4+
5+
Both fixed on branch fix/generate-variants-probability-rates-mismatch.
6+
7+
Root cause (two related bugs):
8+
9+
1. intersect_regions dropped all "middle" mutation regions (those fully
10+
contained within the block) and appended a zero-length fallback when
11+
block_end == last_region_end, returning M items instead of N.
12+
13+
2. generate_variants built probability_rates from the original N-item
14+
mutation_rate_regions list, then called rng.choice(a=local_mut_regions,
15+
p=probability_rates) where len(local_mut_regions) == M. When M != N
16+
(always true for N >= 3 in practice) this raised:
17+
ValueError: 'a' and 'p' must have same length
18+
19+
Fixes:
20+
intersect_regions — rewritten using overlap arithmetic so every region
21+
that intersects the block is included and the tail fallback only fires
22+
when the block genuinely extends past all regions.
23+
24+
generate_variants — probability_rates now derived from local_mut_regions
25+
(after the intersect) so lengths always match; None rates are
26+
substituted with mutation_model.avg_mut_rate.
27+
28+
Note: the single-region case and basic generate_variants behaviour are already
29+
covered on the feature/claude-assisted-tests branch (28 tests). Tests here
30+
cover only the multi-region and None-rate scenarios that were broken.
31+
"""
32+
import pytest
33+
from Bio.Seq import Seq
34+
from Bio.SeqRecord import SeqRecord
35+
36+
from neat.models import MutationModel
37+
from neat.read_simulator.utils.bed_func import intersect_regions
38+
from neat.read_simulator.utils.generate_variants import generate_variants
39+
from neat.read_simulator.utils.options import Options
40+
from neat.variants import ContigVariants
41+
42+
43+
# ---------------------------------------------------------------------------
44+
# Helpers
45+
# ---------------------------------------------------------------------------
46+
47+
def _make_ref(length: int = 400) -> SeqRecord:
48+
seq = "ACGT" * (length // 4)
49+
return SeqRecord(Seq(seq), id="chr1", name="chr1", description="")
50+
51+
52+
def _make_opts(seed: int = 0, min_mutations: int = 1) -> Options:
53+
opts = Options(rng_seed=seed)
54+
opts.ploidy = 2
55+
opts.min_mutations = min_mutations
56+
opts.mutation_rate = None
57+
opts.mutation_bed = None
58+
return opts
59+
60+
61+
def _run_gv(regions, ref=None, seed=0, min_mutations=1):
62+
if ref is None:
63+
ref = _make_ref()
64+
opts = _make_opts(seed=seed, min_mutations=min_mutations)
65+
model = MutationModel()
66+
model.rng = opts.rng
67+
return generate_variants(
68+
reference=ref,
69+
ref_start=0,
70+
mutation_rate_regions=regions,
71+
input_variants=ContigVariants(),
72+
mutation_model=model,
73+
options=opts,
74+
max_qual_score=40,
75+
)
76+
77+
78+
# ===========================================================================
79+
# intersect_regions — regression for the algorithm rewrite
80+
# ===========================================================================
81+
82+
def test_intersect_three_regions_all_included():
83+
"""
84+
Three regions spanning the block exactly — all three must appear in output.
85+
The old algorithm dropped the middle region and returned only 2 items.
86+
"""
87+
regions = [(0, 133, 0.01), (133, 266, 0.02), (266, 400, 0.05)]
88+
result = intersect_regions(regions, (0, 400), 0.0)
89+
assert len(result) == 3
90+
assert result[0] == (0, 133, 0.01)
91+
assert result[1] == (133, 266, 0.02)
92+
assert result[2] == (266, 400, 0.05)
93+
94+
95+
def test_intersect_four_regions_all_included():
96+
"""Four regions, block exactly spans all — all four returned."""
97+
regions = [(0, 100, 0.01), (100, 200, 0.02), (200, 300, 0.03), (300, 400, 0.05)]
98+
result = intersect_regions(regions, (0, 400), 0.0)
99+
assert len(result) == 4
100+
101+
102+
def test_intersect_block_end_equals_last_region_end_no_fallback():
103+
"""
104+
When block_end == last region end the old code appended a zero-length
105+
fallback (last_end, block_end, default). The new code must NOT add it.
106+
"""
107+
regions = [(0, 200, 0.01), (200, 400, 0.05)]
108+
result = intersect_regions(regions, (0, 400), 0.0)
109+
# No zero-length entry
110+
assert all(r[0] < r[1] for r in result)
111+
112+
113+
def test_intersect_block_partially_overlaps_middle_region():
114+
"""Block (150, 350) overlaps parts of all three input regions."""
115+
regions = [(0, 200, 0.01), (200, 300, 0.02), (300, 400, 0.05)]
116+
result = intersect_regions(regions, (150, 350), 0.0)
117+
assert (150, 200, 0.01) in result
118+
assert (200, 300, 0.02) in result
119+
assert (300, 350, 0.05) in result
120+
121+
122+
def test_intersect_result_is_contiguous():
123+
"""Output sub-intervals must be contiguous (each end == next start)."""
124+
regions = [(0, 133, 0.01), (133, 266, 0.02), (266, 400, 0.05)]
125+
result = intersect_regions(regions, (0, 400), 0.0)
126+
for i in range(len(result) - 1):
127+
assert result[i][1] == result[i + 1][0]
128+
129+
130+
def test_intersect_result_covers_full_block():
131+
"""First item starts at block_start, last item ends at block_end."""
132+
regions = [(0, 133, 0.01), (133, 266, 0.02), (266, 400, 0.05)]
133+
result = intersect_regions(regions, (0, 400), 0.0)
134+
assert result[0][0] == 0
135+
assert result[-1][1] == 400
136+
137+
138+
def test_intersect_block_outside_all_regions_returns_default():
139+
"""Block with no overlap with any region → single fallback entry."""
140+
regions = [(0, 100, 0.01), (100, 200, 0.05)]
141+
result = intersect_regions(regions, (300, 500), 0.99)
142+
assert result == [(300, 500, 0.99)]
143+
144+
145+
# ===========================================================================
146+
# generate_variants — crash regression with multiple mutation rate regions
147+
# ===========================================================================
148+
149+
def test_three_mut_regions_no_crash():
150+
"""
151+
Primary crash regression: 3 mutation rate regions over a 400 bp reference.
152+
153+
Before the fix:
154+
intersect_regions returned 2 items; probability_rates had 3 →
155+
ValueError: 'a' and 'p' must have same length
156+
"""
157+
result = _run_gv([(0, 133, 0.01), (133, 266, 0.02), (266, 400, 0.05)])
158+
assert isinstance(result, ContigVariants)
159+
160+
161+
def test_two_mut_regions_no_crash():
162+
"""Two regions — was silently broken (zero-length fallback as second region)."""
163+
result = _run_gv([(0, 200, 0.01), (200, 400, 0.05)])
164+
assert isinstance(result, ContigVariants)
165+
166+
167+
def test_four_mut_regions_no_crash():
168+
"""Four regions — more aggressively exercises the fix."""
169+
result = _run_gv([(0, 100, 0.01), (100, 200, 0.02), (200, 300, 0.03), (300, 400, 0.05)])
170+
assert isinstance(result, ContigVariants)
171+
172+
173+
def test_none_rate_region_no_crash():
174+
"""
175+
None rate (from recalibrate_mutation_regions when no BED rate exists) is
176+
replaced by avg_mut_rate before building probability_rates.
177+
"""
178+
result = _run_gv([(0, 200, None), (200, 400, 0.02)])
179+
assert isinstance(result, ContigVariants)
180+
181+
182+
def test_all_none_rates_no_crash():
183+
"""All None rates fall back entirely to avg_mut_rate."""
184+
result = _run_gv([(0, 200, None), (200, 400, None)])
185+
assert isinstance(result, ContigVariants)
186+
187+
188+
def test_three_regions_produces_variants():
189+
"""Multi-region run still generates at least the requested minimum mutations."""
190+
result = _run_gv([(0, 133, 0.01), (133, 266, 0.02), (266, 400, 0.05)],
191+
min_mutations=5)
192+
assert len(result.variant_locations) >= 1
193+
194+
195+
def test_multi_region_variant_positions_in_bounds():
196+
"""All variant positions fall within the reference after the fix."""
197+
ref = _make_ref(400)
198+
result = _run_gv([(0, 133, 0.01), (133, 266, 0.02), (266, 400, 0.05)],
199+
ref=ref, min_mutations=5)
200+
for loc in result.variant_locations:
201+
assert 0 <= loc < len(ref)
202+
203+
204+
def test_multi_region_reproducible_with_same_seed():
205+
"""Same seed produces identical variant locations with multiple regions."""
206+
regions = [(0, 133, 0.01), (133, 266, 0.02), (266, 400, 0.05)]
207+
r1 = _run_gv(regions, seed=7, min_mutations=5)
208+
r2 = _run_gv(regions, seed=7, min_mutations=5)
209+
assert r1.variant_locations == r2.variant_locations

0 commit comments

Comments
 (0)