@@ -2005,13 +2005,8 @@ def test_ld_matrix_multipop(self, stat_method_name):
20052005 stat_method (ss_sizes , ss , indexes , col_sites , row_sites , None , None , "node" )
20062006
20072007 def test_two_locus_count_stat (self ):
2008+ """Test two_locus_count_stat on biallelic data (no norm function)"""
20082009 ts = self .get_example_tree_sequence (10 )
2009- # Multiallelic test case to test norm function
2010- ts_multi = self .get_example_tree_sequence_multiallelic ()
2011- assert (ts .get_samples () == ts_multi .get_samples ()).all (), (
2012- "biallelic and multiallelic test case are expected "
2013- "to have the same sample nodes"
2014- )
20152010 ss = ts .get_samples () # sample sets
20162011 ss_sizes = np .array ([len (ss )], dtype = np .uint32 )
20172012 row_sites = np .arange (ts .get_num_sites (), dtype = np .int32 )
@@ -2029,37 +2024,33 @@ def stat_func(X, n):
20292024 pB = paB + pAB
20302025 return pAB - (pA * pB )
20312026
2032- def norm_func (X , n , nA , nB ):
2033- return X [ 0 ]. sum ( keepdims = True ) / n . sum ()
2027+ def norm_func (* _ ):
2028+ raise Exception # norm function will not be used
20342029
2035- method = ts .two_locus_count_stat # most tests on biallelic
2030+ method = ts .two_locus_count_stat
20362031 site_args = row_sites , col_sites , None , None , "site"
20372032 branch_args = None , None , row_pos , col_pos , "branch"
2033+ # happy path
20382034 a = method (ss_sizes , ss , stat_func , norm_func , 1 , True , * site_args )
2039- assert a .shape == (10 , 10 , 1 )
2035+ assert a .shape == (ts . get_num_sites (), ts . get_num_sites () , 1 )
20402036 a = method (ss_sizes , ss , stat_func , norm_func , 1 , True , * branch_args )
2041- assert a .shape == (2 , 2 , 1 )
2037+ assert a .shape == (ts .get_num_trees (), ts .get_num_trees (), 1 )
2038+ # happy path - sites and positions passed as lists are also valid
20422039 site_list_args = row_sites_list , col_sites_list , None , None , "site"
20432040 branch_list_args = None , None , row_pos_list , col_pos_list , "branch"
2044-
2045- # happy path
20462041 a = method (ss_sizes , ss , stat_func , norm_func , 1 , True , * site_list_args )
2047- assert a .shape == (10 , 10 , 1 ) # ts has 10 sites
2042+ assert a .shape == (ts . get_num_sites (), ts . get_num_sites () , 1 )
20482043 a = method (ss_sizes , ss , stat_func , norm_func , 1 , True , * branch_list_args )
2049- assert a .shape == (2 , 2 , 1 ) # ts has 2 trees
2050- a = ts_multi .two_locus_count_stat (
2044+ assert a .shape == (ts .get_num_trees (), ts .get_num_trees (), 1 )
2045+ # happy path - default array filling
2046+ a = method (
20512047 ss_sizes , ss , stat_func , norm_func , 1 , True , None , None , None , None , "site"
20522048 )
2053- import platform
2054-
2055- if platform .system () == "Darwin" :
2056- assert a .shape == (54 , 54 , 1 ) # ts has 54 sites on macos?
2057- else :
2058- assert a .shape == (56 , 56 , 1 ) # ts has 56 sites
2059- a = ts_multi .two_locus_count_stat (
2049+ assert a .shape == (ts .get_num_sites (), ts .get_num_sites (), 1 )
2050+ a = method (
20602051 ss_sizes , ss , stat_func , norm_func , 1 , True , None , None , None , None , "branch"
20612052 )
2062- assert a .shape == (48 , 48 , 1 ) # ts has 48 trees
2053+ assert a .shape == (ts . get_num_trees (), ts . get_num_trees () , 1 )
20632054 # CPython API errors
20642055 with pytest .raises (ValueError , match = "Sum of sample_set_sizes" ):
20652056 bad_ss = np .array ([], dtype = np .int32 )
@@ -2136,50 +2127,17 @@ def norm_func(X, n, nA, nB):
21362127 method (ss_sizes , ss , lambda * _ : [1 , 2 ], norm_func , 1 , True , * site_args )
21372128 with pytest .raises (ValueError , match = "could not convert string to float" ):
21382129 method (ss_sizes , ss , lambda * _ : ["nonfloat" ], norm_func , 1 , True , * site_args )
2139- with pytest .raises (ValueError , match = "norm function.*must be 1D" ):
2140- ts_multi .two_locus_count_stat (
2141- ss_sizes , ss , stat_func , lambda * _ : 1 , 1 , True , * site_args
2142- )
2143- with pytest .raises (
2144- TypeError , match = "takes 1 positional argument but 2 were given"
2145- ):
2146- ts_multi .two_locus_count_stat (
2147- ss_sizes , ss , lambda _ : 1 , norm_func , 1 , True , * site_args
2148- )
2149- with pytest .raises (ValueError , match = "norm function.*length 2; must be 1" ):
2150- ts_multi .two_locus_count_stat (
2151- ss_sizes , ss , stat_func , lambda * _ : [1 , 2 ], 1 , True , * site_args
2152- )
2153- with pytest .raises (
2154- TypeError , match = "takes 1 positional argument but 4 were given"
2155- ):
2156- ts_multi .two_locus_count_stat (
2157- ss_sizes , ss , stat_func , lambda _ : [1 , 2 ], 1 , True , * site_args
2158- )
2159- with pytest .raises (ValueError , match = "could not convert string to float" ):
2160- ts_multi .two_locus_count_stat (
2161- ss_sizes , ss , stat_func , lambda * _ : ["nonfloat" ], 1 , True , * site_args
2162- )
2163- # Exceptions within stat_func and norm_func are correctly raised.
2130+ # Exceptions within stat_func are correctly raised.
21642131 for exception in [ValueError , TypeError ]:
21652132
21662133 def stat_func_except (* _ ):
21672134 raise exception ("test" ) # noqa: B023
21682135
2169- def norm_func_except (* _ ):
2170- raise exception ("test" ) # noqa: B023
2171-
2172- with pytest .raises (exception , match = "test" ):
2173- method (
2174- ss_sizes , ss , stat_func_except , norm_func , 1 , True , * site_list_args
2175- )
21762136 with pytest .raises (exception , match = "test" ):
2177- ts_multi .two_locus_count_stat (
2178- ss_sizes , ss , stat_func , norm_func_except , 1 , True , * site_list_args
2179- )
2137+ method (ss_sizes , ss , stat_func_except , norm_func , 1 , True , * site_args )
21802138 # C API errors
21812139 with pytest .raises (tskit .LibraryError , match = "TSK_ERR_BAD_RESULT_DIMS" ):
2182- method (ss_sizes , ss , stat_func , norm_func , 0 , True , * site_list_args )
2140+ method (ss_sizes , ss , stat_func , norm_func , 0 , True , * site_args )
21832141 with pytest .raises (tskit .LibraryError , match = "TSK_ERR_STAT_UNSORTED_SITES" ):
21842142 bad_sites = np .array ([1 , 0 , 2 ], dtype = np .int32 )
21852143 bad_site_args = bad_sites , col_sites , None , None , "site"
@@ -2229,6 +2187,56 @@ def norm_func_except(*_):
22292187 bad_branch_args = None , None , row_pos , bad_pos , "branch"
22302188 method (ss_sizes , ss , stat_func , norm_func , 1 , True , * bad_branch_args )
22312189
2190+ def test_two_locus_count_stat_multialleliic (self ):
2191+ """
2192+ Test two_locus_count_stat on multiallelic sites to test the behavior of
2193+ the norm function.
2194+ """
2195+ ts = self .get_example_tree_sequence_multiallelic ()
2196+
2197+ def stat_func (X , n ):
2198+ pAB , pAb , paB = X / n
2199+ pA = pAb + pAB
2200+ pB = paB + pAB
2201+ return pAB - (pA * pB )
2202+
2203+ def norm_func (X , n , nA , nB ):
2204+ return X [0 ].sum (keepdims = True ) / n .sum ()
2205+
2206+ ss = ts .get_samples () # a single sample set containing every sample
2207+ ss_sizes = np .array ([len (ss )], dtype = np .uint32 )
2208+ row_sites = np .arange (ts .get_num_sites (), dtype = np .int32 )
2209+ col_sites = row_sites
2210+ method = ts .two_locus_count_stat
2211+ site_args = row_sites , col_sites , None , None , "site"
2212+
2213+ # happy path: the result has one row and one column per site
2214+ a = method (ss_sizes , ss , stat_func , norm_func , 1 , True , * site_args )
2215+ assert a .shape == (ts .get_num_sites (), ts .get_num_sites (), 1 )
2216+ # CPython API errors: bad return values or signatures of the stat/norm functions
2217+ with pytest .raises (ValueError , match = "norm function.*must be 1D" ):
2218+ method (ss_sizes , ss , stat_func , lambda * _ : 1 , 1 , True , * site_args )
2219+ with pytest .raises (
2220+ TypeError , match = "takes 1 positional argument but 2 were given"
2221+ ):
2222+ method (ss_sizes , ss , lambda _ : 1 , norm_func , 1 , True , * site_args )
2223+ with pytest .raises (ValueError , match = "norm function.*length 2; must be 1" ):
2224+ method (ss_sizes , ss , stat_func , lambda * _ : [1 , 2 ], 1 , True , * site_args )
2225+ with pytest .raises (
2226+ TypeError , match = "takes 1 positional argument but 4 were given"
2227+ ):
2228+ method (ss_sizes , ss , stat_func , lambda _ : [1 , 2 ], 1 , True , * site_args )
2229+ with pytest .raises (ValueError , match = "could not convert string to float" ):
2230+ method (ss_sizes , ss , stat_func , lambda * _ : ["nonfloat" ], 1 , True , * site_args )
2231+ # Exceptions within norm_func are correctly raised.
2232+ for exception in [ValueError , TypeError ]:
2233+
2234+ def norm_func_except (* _ ):
2235+ raise exception ("test" ) # noqa: B023
2236+
2237+ with pytest .raises (exception , match = "test" ):
2238+ method (ss_sizes , ss , stat_func , norm_func_except , 1 , True , * site_args )
2239+
22322240 def test_kc_distance_errors (self ):
22332241 ts1 = self .get_example_tree_sequence (10 )
22342242 with pytest .raises (TypeError ):
0 commit comments