Skip to content

Commit 4c04ff3

Browse files
committed
Clean up python C tests
Use the number of sites and trees reported by the tree sequence instead of hard coded values. This has the benefit of being more readable, communicating intent (review comment from Peter). Split the multiallelic and biallelic test cases, they're getting messy. Now I can explicitly assert that the norm_func is not run for biallelic sites and for branch stats. Also gets rid of awkward assertions about sample sets.
1 parent bd0a1a5 commit 4c04ff3

1 file changed

Lines changed: 68 additions & 60 deletions

File tree

python/tests/test_python_c.py

Lines changed: 68 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,13 +2005,8 @@ def test_ld_matrix_multipop(self, stat_method_name):
20052005
stat_method(ss_sizes, ss, indexes, col_sites, row_sites, None, None, "node")
20062006

20072007
def test_two_locus_count_stat(self):
2008+
"""Test two_locus_count_stat on biallelic data (no norm function)"""
20082009
ts = self.get_example_tree_sequence(10)
2009-
# Multiallelic test case to test norm function
2010-
ts_multi = self.get_example_tree_sequence_multiallelic()
2011-
assert (ts.get_samples() == ts_multi.get_samples()).all(), (
2012-
"biallelic and multiallelic test case are expected "
2013-
"to have the same sample nodes"
2014-
)
20152010
ss = ts.get_samples() # sample sets
20162011
ss_sizes = np.array([len(ss)], dtype=np.uint32)
20172012
row_sites = np.arange(ts.get_num_sites(), dtype=np.int32)
@@ -2029,37 +2024,33 @@ def stat_func(X, n):
20292024
pB = paB + pAB
20302025
return pAB - (pA * pB)
20312026

2032-
def norm_func(X, n, nA, nB):
2033-
return X[0].sum(keepdims=True) / n.sum()
2027+
def norm_func(*_):
2028+
raise Exception # norm function will not be used
20342029

2035-
method = ts.two_locus_count_stat # most tests on biallelic
2030+
method = ts.two_locus_count_stat
20362031
site_args = row_sites, col_sites, None, None, "site"
20372032
branch_args = None, None, row_pos, col_pos, "branch"
2033+
# happy path
20382034
a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args)
2039-
assert a.shape == (10, 10, 1)
2035+
assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1)
20402036
a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_args)
2041-
assert a.shape == (2, 2, 1)
2037+
assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1)
2038+
# happy path - sample sets as lists are also valid
20422039
site_list_args = row_sites_list, col_sites_list, None, None, "site"
20432040
branch_list_args = None, None, row_pos_list, col_pos_list, "branch"
2044-
2045-
# happy path
20462041
a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_list_args)
2047-
assert a.shape == (10, 10, 1) # ts has 10 sites
2042+
assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1)
20482043
a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args)
2049-
assert a.shape == (2, 2, 1) # ts has 2 trees
2050-
a = ts_multi.two_locus_count_stat(
2044+
assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1)
2045+
# happy path - default array filling
2046+
a = method(
20512047
ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site"
20522048
)
2053-
import platform
2054-
2055-
if platform.system() == "Darwin":
2056-
assert a.shape == (54, 54, 1) # ts has 54 sites on macos?
2057-
else:
2058-
assert a.shape == (56, 56, 1) # ts has 56 sites
2059-
a = ts_multi.two_locus_count_stat(
2049+
assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1)
2050+
a = method(
20602051
ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "branch"
20612052
)
2062-
assert a.shape == (48, 48, 1) # ts has 48 trees
2053+
assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1)
20632054
# CPython API errors
20642055
with pytest.raises(ValueError, match="Sum of sample_set_sizes"):
20652056
bad_ss = np.array([], dtype=np.int32)
@@ -2136,50 +2127,17 @@ def norm_func(X, n, nA, nB):
21362127
method(ss_sizes, ss, lambda *_: [1, 2], norm_func, 1, True, *site_args)
21372128
with pytest.raises(ValueError, match="could not convert string to float"):
21382129
method(ss_sizes, ss, lambda *_: ["nonfloat"], norm_func, 1, True, *site_args)
2139-
with pytest.raises(ValueError, match="norm function.*must be 1D"):
2140-
ts_multi.two_locus_count_stat(
2141-
ss_sizes, ss, stat_func, lambda *_: 1, 1, True, *site_args
2142-
)
2143-
with pytest.raises(
2144-
TypeError, match="takes 1 positional argument but 2 were given"
2145-
):
2146-
ts_multi.two_locus_count_stat(
2147-
ss_sizes, ss, lambda _: 1, norm_func, 1, True, *site_args
2148-
)
2149-
with pytest.raises(ValueError, match="norm function.*length 2; must be 1"):
2150-
ts_multi.two_locus_count_stat(
2151-
ss_sizes, ss, stat_func, lambda *_: [1, 2], 1, True, *site_args
2152-
)
2153-
with pytest.raises(
2154-
TypeError, match="takes 1 positional argument but 4 were given"
2155-
):
2156-
ts_multi.two_locus_count_stat(
2157-
ss_sizes, ss, stat_func, lambda _: [1, 2], 1, True, *site_args
2158-
)
2159-
with pytest.raises(ValueError, match="could not convert string to float"):
2160-
ts_multi.two_locus_count_stat(
2161-
ss_sizes, ss, stat_func, lambda *_: ["nonfloat"], 1, True, *site_args
2162-
)
2163-
# Exceptions within stat_func and norm_func are correctly raised.
2130+
# Exceptions within stat_func are correctly raised.
21642131
for exception in [ValueError, TypeError]:
21652132

21662133
def stat_func_except(*_):
21672134
raise exception("test") # noqa: B023
21682135

2169-
def norm_func_except(*_):
2170-
raise exception("test") # noqa: B023
2171-
2172-
with pytest.raises(exception, match="test"):
2173-
method(
2174-
ss_sizes, ss, stat_func_except, norm_func, 1, True, *site_list_args
2175-
)
21762136
with pytest.raises(exception, match="test"):
2177-
ts_multi.two_locus_count_stat(
2178-
ss_sizes, ss, stat_func, norm_func_except, 1, True, *site_list_args
2179-
)
2137+
method(ss_sizes, ss, stat_func_except, norm_func, 1, True, *site_args)
21802138
# C API errors
21812139
with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_RESULT_DIMS"):
2182-
method(ss_sizes, ss, stat_func, norm_func, 0, True, *site_list_args)
2140+
method(ss_sizes, ss, stat_func, norm_func, 0, True, *site_args)
21832141
with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"):
21842142
bad_sites = np.array([1, 0, 2], dtype=np.int32)
21852143
bad_site_args = bad_sites, col_sites, None, None, "site"
@@ -2229,6 +2187,56 @@ def norm_func_except(*_):
22292187
bad_branch_args = None, None, row_pos, bad_pos, "branch"
22302188
method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args)
22312189

2190+
def test_two_locus_count_stat_multialleliic(self):
2191+
"""
2192+
Test two_locus_count_stat on multiallelic sites to test the behavior of
2193+
the norm function.
2194+
"""
2195+
ts = self.get_example_tree_sequence_multiallelic()
2196+
2197+
def stat_func(X, n):
2198+
pAB, pAb, paB = X / n
2199+
pA = pAb + pAB
2200+
pB = paB + pAB
2201+
return pAB - (pA * pB)
2202+
2203+
def norm_func(X, n, nA, nB):
2204+
return X[0].sum(keepdims=True) / n.sum()
2205+
2206+
ss = ts.get_samples() # sample sets
2207+
ss_sizes = np.array([len(ss)], dtype=np.uint32)
2208+
row_sites = np.arange(ts.get_num_sites(), dtype=np.int32)
2209+
col_sites = row_sites
2210+
method = ts.two_locus_count_stat
2211+
site_args = row_sites, col_sites, None, None, "site"
2212+
2213+
# happy path
2214+
a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args)
2215+
assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1)
2216+
# CPython API errors
2217+
with pytest.raises(ValueError, match="norm function.*must be 1D"):
2218+
method(ss_sizes, ss, stat_func, lambda *_: 1, 1, True, *site_args)
2219+
with pytest.raises(
2220+
TypeError, match="takes 1 positional argument but 2 were given"
2221+
):
2222+
method(ss_sizes, ss, lambda _: 1, norm_func, 1, True, *site_args)
2223+
with pytest.raises(ValueError, match="norm function.*length 2; must be 1"):
2224+
method(ss_sizes, ss, stat_func, lambda *_: [1, 2], 1, True, *site_args)
2225+
with pytest.raises(
2226+
TypeError, match="takes 1 positional argument but 4 were given"
2227+
):
2228+
method(ss_sizes, ss, stat_func, lambda _: [1, 2], 1, True, *site_args)
2229+
with pytest.raises(ValueError, match="could not convert string to float"):
2230+
method(ss_sizes, ss, stat_func, lambda *_: ["nonfloat"], 1, True, *site_args)
2231+
# Exceptions within stat_func are correctly raised.
2232+
for exception in [ValueError, TypeError]:
2233+
2234+
def norm_func_except(*_):
2235+
raise exception("test") # noqa: B023
2236+
2237+
with pytest.raises(exception, match="test"):
2238+
method(ss_sizes, ss, stat_func, norm_func_except, 1, True, *site_args)
2239+
22322240
def test_kc_distance_errors(self):
22332241
ts1 = self.get_example_tree_sequence(10)
22342242
with pytest.raises(TypeError):

0 commit comments

Comments
 (0)