Skip to content

Commit 593161a

Browse files
committed
Update tests according to Peters's feedback
*Python tests* Overhaul python testing of the general stat functions. Remove the dependence on the example tree sequences, opting instead to simulate a couple of examples directly. Use these simulated trees in test fixtures, scoped at the module level. This streamlines the test parameterization a lot. Use the single stat site names from the summary function definitions. *CPython tests* Add a multiallelic tree sequence to test normalisation function validation and errors. Remove one more occurrence of `np.expand_dims`. *trees.c* Remove the unnecessary branch in tsk_treeseq_two_locus_count_general_stat, improving the code coverage. *trees.py* Default normalisation function can be None, applying default at runtime. Simplifies calling code and is more in line with the rest of the API.
1 parent bdc8092 commit 593161a

4 files changed

Lines changed: 165 additions & 124 deletions

File tree

c/tskit/trees.c

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3439,21 +3439,22 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self,
34393439
ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets,
34403440
sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows,
34413441
row_sites, out_cols, col_sites, options, result);
3442-
} else if (stat_branch) {
3443-
ret = check_positions(
3444-
row_positions, out_rows, tsk_treeseq_get_sequence_length(self));
3445-
if (ret != 0) {
3446-
goto out;
3447-
}
3448-
ret = check_positions(
3449-
col_positions, out_cols, tsk_treeseq_get_sequence_length(self));
3450-
if (ret != 0) {
3451-
goto out;
3452-
}
3453-
ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets,
3454-
sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows,
3455-
row_positions, out_cols, col_positions, options, result);
3442+
goto out;
3443+
}
3444+
tsk_bug_assert(stat_branch);
3445+
ret = check_positions(
3446+
row_positions, out_rows, tsk_treeseq_get_sequence_length(self));
3447+
if (ret != 0) {
3448+
goto out;
3449+
}
3450+
ret = check_positions(
3451+
col_positions, out_cols, tsk_treeseq_get_sequence_length(self));
3452+
if (ret != 0) {
3453+
goto out;
34563454
}
3455+
ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets,
3456+
sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows,
3457+
row_positions, out_cols, col_positions, options, result);
34573458
out:
34583459
return ret;
34593460
}

python/tests/test_ld_matrix.py

Lines changed: 63 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -2538,55 +2538,49 @@ def D2_ij_unbiased(X, n):
25382538
)
25392539

25402540

2541-
@pytest.mark.parametrize(
2542-
"ts,stat",
2543-
[
2544-
(
2545-
ts := tsutil.get_sim_example(
2546-
sample_size=100,
2547-
sequence_length=32,
2548-
recombination_rate=0.5,
2549-
mutation_rate=0.1,
2550-
seed=123,
2551-
),
2552-
"D",
2541+
@pytest.fixture(scope="module")
2542+
def ts_100_samp_with_sites_fixture():
2543+
ts = tsutil.get_sim_example(
2544+
sample_size=100,
2545+
sequence_length=32,
2546+
recombination_rate=0.5,
2547+
mutation_rate=0.1,
2548+
seed=123,
2549+
)
2550+
assert ts.num_sites > 0, "sites are required"
2551+
assert ts.num_samples == 100, "100 samples are required"
2552+
return ts
2553+
2554+
2555+
@pytest.fixture(scope="module")
2556+
def ts_multiallelic_fixture():
2557+
ts = msprime.sim_mutations(
2558+
msprime.sim_ancestry(
2559+
2, recombination_rate=0.1, sequence_length=100, random_seed=123
25532560
),
2554-
(ts, "D2"),
2555-
(ts, "r2"),
2556-
(ts, "r"),
2557-
(ts, "D_prime"),
2558-
(ts, "Dz"),
2559-
(ts, "pi2"),
2560-
(ts, "D2_unbiased"),
2561-
(ts, "Dz_unbiased"),
2562-
(ts, "pi2_unbiased"),
2563-
],
2564-
)
2565-
def test_general_two_locus_site_stat(ts, stat):
2561+
rate=0.1,
2562+
random_seed=123,
2563+
)
2564+
# Need at least 4 samples to test unbiased statistics
2565+
assert ts.num_samples >= 4, "At least 4 samples required"
2566+
assert max({len(s.mutations) for s in ts.sites()}) > 2, (
2567+
"At least one multiallelic site required"
2568+
)
2569+
return ts
2570+
2571+
2572+
@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys())
2573+
def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture):
2574+
ts = ts_100_samp_with_sites_fixture
25662575
sample_sets = [ts.samples()[0:50], ts.samples()[50:100]]
25672576
ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2)
25682577
ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat)
25692578
np.testing.assert_array_almost_equal(ldg, ld)
25702579

25712580

2572-
@pytest.mark.parametrize(
2573-
"ts,stat",
2574-
[
2575-
(
2576-
ts := tsutil.get_sim_example(
2577-
sample_size=100,
2578-
sequence_length=32,
2579-
recombination_rate=0.5,
2580-
mutation_rate=0.1,
2581-
seed=123,
2582-
),
2583-
"r2_ij",
2584-
),
2585-
(ts, "D2_ij"),
2586-
(ts, "D2_ij_unbiased"),
2587-
],
2588-
)
2589-
def test_general_two_locus_two_way_site_stat(ts, stat):
2581+
@pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"])
2582+
def test_general_two_locus_two_way_site_stat(stat, ts_100_samp_with_sites_fixture):
2583+
ts = ts_100_samp_with_sites_fixture
25902584
sample_sets = [ts.samples()[0:50], ts.samples()[50:100]]
25912585
ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1)
25922586
ld = ts.ld_matrix(
@@ -2595,62 +2589,31 @@ def test_general_two_locus_two_way_site_stat(ts, stat):
25952589
np.testing.assert_array_almost_equal(ldg, ld)
25962590

25972591

2598-
@pytest.mark.parametrize(
2599-
"stat",
2600-
[
2601-
"D",
2602-
"D2",
2603-
"r2",
2604-
"r",
2605-
"D_prime",
2606-
"Dz",
2607-
"pi2",
2608-
"D2_unbiased",
2609-
"Dz_unbiased",
2610-
"pi2_unbiased",
2611-
],
2612-
)
2613-
def test_general_one_way_two_locus_stat_multiallelic(stat):
2614-
ts = tsutil.all_fields_ts()
2615-
func = getattr(GeneralStatFuncs, stat)
2616-
if stat == "r2":
2617-
result = ts.two_locus_count_stat(
2618-
[ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n
2619-
)
2620-
elif stat in {"D", "r", "D_prime"}:
2621-
result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True)
2622-
else:
2623-
# default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]`
2624-
result = ts.two_locus_count_stat([ts.samples()], func, 1)
2625-
np.testing.assert_array_almost_equal(ts.ld_matrix(stat=stat), result)
2626-
2627-
2628-
@pytest.mark.parametrize(
2629-
"stat",
2630-
[
2631-
"r2_ij",
2632-
"D2_ij",
2633-
"D2_ij_unbiased",
2634-
],
2635-
)
2636-
def test_general_two_way_two_locus_stat_multiallelic(stat):
2637-
ts = tsutil.all_fields_ts()
2638-
func = getattr(GeneralStatFuncs, stat)
2639-
if stat == "r2_ij":
2640-
result = ts.two_locus_count_stat(
2641-
[ts.samples(), ts.samples()],
2642-
func,
2643-
1,
2644-
lambda X, n, nA, nB: X[0].sum(keepdims=True) / n.sum(),
2645-
)
2646-
else:
2647-
# default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]`
2648-
result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1)
2649-
np.testing.assert_array_almost_equal(
2650-
ts.ld_matrix(
2651-
stat=stat.replace("_ij", ""),
2652-
indexes=(0, 1),
2653-
sample_sets=[ts.samples(), ts.samples()],
2654-
),
2655-
result,
2592+
@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys())
2593+
def test_general_one_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixture):
2594+
ts = ts_multiallelic_fixture
2595+
general_func = getattr(GeneralStatFuncs, stat)
2596+
norm_func = (lambda X, n, nA, nB: X[0] / n) if stat == "r2" else None
2597+
polarised = POLARIZATION[SUMMARY_FUNCS[stat]]
2598+
ldg = ts.two_locus_count_stat(
2599+
[ts.samples()], general_func, 1, norm_f=norm_func, polarised=polarised
2600+
)
2601+
ld = ts.ld_matrix(stat=stat)
2602+
np.testing.assert_array_almost_equal(ld, ldg)
2603+
2604+
2605+
@pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"])
2606+
def test_general_two_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixture):
2607+
ts = ts_multiallelic_fixture
2608+
general_func = getattr(GeneralStatFuncs, stat)
2609+
norm_func = (
2610+
(lambda X, n, nA, nB: X[0].sum(keepdims=True) / n.sum())
2611+
if stat == "r2_ij"
2612+
else None
2613+
)
2614+
sample_sets = [ts.samples(), ts.samples()]
2615+
ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_func)
2616+
ld = ts.ld_matrix(
2617+
stat=stat.replace("_ij", ""), indexes=(0, 1), sample_sets=sample_sets
26562618
)
2619+
np.testing.assert_array_almost_equal(ld, ldg)

python/tests/test_python_c.py

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,23 @@ def get_example_migration_tree_sequence(self):
138138
)
139139
return ts.ll_tree_sequence
140140

141+
def get_example_tree_sequence_multiallelic(self, sample_size=10):
142+
ts = msprime.sim_mutations(
143+
msprime.sim_ancestry(
144+
sample_size,
145+
recombination_rate=0.1,
146+
sequence_length=100,
147+
ploidy=1,
148+
random_seed=123,
149+
),
150+
rate=0.1,
151+
random_seed=123,
152+
)
153+
assert max({len(s.mutations) for s in ts.sites()}) > 2, (
154+
"At least one multiallelic site required"
155+
)
156+
return ts.ll_tree_sequence
157+
141158
def verify_iterator(self, iterator):
142159
"""
143160
Checks that the specified non-empty iterator implements the
@@ -1989,6 +2006,12 @@ def test_ld_matrix_multipop(self, stat_method_name):
19892006

19902007
def test_two_locus_count_stat(self):
19912008
ts = self.get_example_tree_sequence(10)
2009+
# Multiallelic test case to test norm function
2010+
ts_multi = self.get_example_tree_sequence_multiallelic()
2011+
assert (ts.get_samples() == ts_multi.get_samples()).all(), (
2012+
"biallelic and multiallelic test case are expected "
2013+
"to have the same sample nodes"
2014+
)
19922015
ss = ts.get_samples() # sample sets
19932016
ss_sizes = np.array([len(ss)], dtype=np.uint32)
19942017
row_sites = np.arange(ts.get_num_sites(), dtype=np.int32)
@@ -2007,10 +2030,9 @@ def stat_func(X, n):
20072030
return pAB - (pA * pB)
20082031

20092032
def norm_func(X, n, nA, nB):
2010-
return np.expand_dims(X[0].sum() / n.sum(), axis=0)
2011-
2012-
method = ts.two_locus_count_stat
2033+
return X[0].sum(keepdims=True) / n.sum()
20132034

2035+
method = ts.two_locus_count_stat # most tests on biallelic
20142036
site_args = row_sites, col_sites, None, None, "site"
20152037
branch_args = None, None, row_pos, col_pos, "branch"
20162038
a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args)
@@ -2019,10 +2041,20 @@ def norm_func(X, n, nA, nB):
20192041
assert a.shape == (2, 2, 1)
20202042
site_list_args = row_sites_list, col_sites_list, None, None, "site"
20212043
branch_list_args = None, None, row_pos_list, col_pos_list, "branch"
2044+
2045+
# happy path
20222046
a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_list_args)
2023-
assert a.shape == (10, 10, 1)
2047+
assert a.shape == (10, 10, 1) # ts has 10 sites
20242048
a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args)
2025-
assert a.shape == (2, 2, 1)
2049+
assert a.shape == (2, 2, 1) # ts has 2 trees
2050+
a = ts_multi.two_locus_count_stat(
2051+
ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site"
2052+
)
2053+
assert a.shape == (56, 56, 1) # ts has 56 sites
2054+
a = ts_multi.two_locus_count_stat(
2055+
ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "branch"
2056+
)
2057+
assert a.shape == (48, 48, 1) # ts has 48 trees
20262058
# CPython API errors
20272059
with pytest.raises(ValueError, match="Sum of sample_set_sizes"):
20282060
bad_ss = np.array([], dtype=np.int32)
@@ -2094,10 +2126,55 @@ def norm_func(X, n, nA, nB):
20942126
with pytest.raises(TypeError, match="norm_func must be callable"):
20952127
method(ss_sizes, ss, stat_func, "uncallable", 1, True, *site_args)
20962128
with pytest.raises(ValueError, match="summary function.*must be 1D"):
2097-
method(ss_sizes, ss, lambda a, b: 1, norm_func, 1, True, *site_args)
2098-
with pytest.raises(ValueError, match="length 2; must be 1"):
2099-
method(ss_sizes, ss, lambda a, b: [1, 2], norm_func, 1, True, *site_args)
2129+
method(ss_sizes, ss, lambda *_: 1, norm_func, 1, True, *site_args)
2130+
with pytest.raises(ValueError, match="summary function.*length 2; must be 1"):
2131+
method(ss_sizes, ss, lambda *_: [1, 2], norm_func, 1, True, *site_args)
2132+
with pytest.raises(ValueError, match="could not convert string to float"):
2133+
method(ss_sizes, ss, lambda *_: ["nonfloat"], norm_func, 1, True, *site_args)
2134+
with pytest.raises(ValueError, match="norm function.*must be 1D"):
2135+
ts_multi.two_locus_count_stat(
2136+
ss_sizes, ss, stat_func, lambda *_: 1, 1, True, *site_args
2137+
)
2138+
with pytest.raises(
2139+
TypeError, match="takes 1 positional argument but 2 were given"
2140+
):
2141+
ts_multi.two_locus_count_stat(
2142+
ss_sizes, ss, lambda _: 1, norm_func, 1, True, *site_args
2143+
)
2144+
with pytest.raises(ValueError, match="norm function.*length 2; must be 1"):
2145+
ts_multi.two_locus_count_stat(
2146+
ss_sizes, ss, stat_func, lambda *_: [1, 2], 1, True, *site_args
2147+
)
2148+
with pytest.raises(
2149+
TypeError, match="takes 1 positional argument but 4 were given"
2150+
):
2151+
ts_multi.two_locus_count_stat(
2152+
ss_sizes, ss, stat_func, lambda _: [1, 2], 1, True, *site_args
2153+
)
2154+
with pytest.raises(ValueError, match="could not convert string to float"):
2155+
ts_multi.two_locus_count_stat(
2156+
ss_sizes, ss, stat_func, lambda *_: ["nonfloat"], 1, True, *site_args
2157+
)
2158+
# Exceptions within stat_func and norm_func are correctly raised.
2159+
for exception in [ValueError, TypeError]:
2160+
2161+
def stat_func_except(*_):
2162+
raise exception("test")
2163+
2164+
def norm_func_except(*_):
2165+
raise exception("test")
2166+
2167+
with pytest.raises(exception, match="test"):
2168+
method(
2169+
ss_sizes, ss, stat_func_except, norm_func, 1, True, *site_list_args
2170+
)
2171+
with pytest.raises(exception, match="test"):
2172+
ts_multi.two_locus_count_stat(
2173+
ss_sizes, ss, stat_func, norm_func_except, 1, True, *site_list_args
2174+
)
21002175
# C API errors
2176+
with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_RESULT_DIMS"):
2177+
method(ss_sizes, ss, stat_func, norm_func, 0, True, *site_list_args)
21012178
with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"):
21022179
bad_sites = np.array([1, 0, 2], dtype=np.int32)
21032180
bad_site_args = bad_sites, col_sites, None, None, "site"

python/tskit/trees.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10936,7 +10936,7 @@ def two_locus_count_stat(
1093610936
sample_sets,
1093710937
f,
1093810938
result_dim,
10939-
norm_f=lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,],
10939+
norm_f=None,
1094010940
polarised=False,
1094110941
sites=None,
1094210942
positions=None,
@@ -10949,7 +10949,7 @@ def two_locus_count_stat(
1094910949
sample_set_sizes,
1095010950
sample_sets,
1095110951
f,
10952-
norm_f,
10952+
norm_f or (lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]),
1095310953
result_dim,
1095410954
polarised,
1095510955
row_sites,

0 commit comments

Comments
 (0)