@@ -1976,69 +1976,91 @@ def test_individuals_nodes(self, ts_fixture):
19761976
19771977 @pytest .mark .skipif (not _tskit .HAS_NUMPY_2 , reason = "Requires NumPy 2.0+" )
19781978 @pytest .mark .parametrize (
1979- "site_lengths" ,
1979+ "string_array" , ["sites_ancestral_state" , "mutations_derived_state" ]
1980+ )
1981+ @pytest .mark .parametrize (
1982+ "str_lengths" ,
19801983 ["none" , "all-0" , "all-1" , "all-2" , "mixed" , "very_long" , "unicode" ],
19811984 )
1982- def test_sites_ancestral_state (self , ts_fixture , site_lengths ):
1983- if site_lengths == "none" :
1985+ def test_string_arrays (self , ts_fixture , str_lengths , string_array ):
1986+ if str_lengths == "none" :
19841987 ts = tskit .TableCollection (1.0 ).tree_sequence ()
19851988 else :
1986- if site_lengths == "all-1" :
1989+ if str_lengths == "all-1" :
19871990 ts = ts_fixture
1988- assert {len (site .ancestral_state ) for site in ts .sites ()} == {1 }
1991+ if string_array == "sites_ancestral_state" :
1992+ assert ts .num_sites > 0
1993+ assert {len (site .ancestral_state ) for site in ts .sites ()} == {1 }
1994+ elif string_array == "mutations_derived_state" :
1995+ assert ts .num_mutations > 0
1996+ assert {len (mut .derived_state ) for mut in ts .mutations ()} == {1 }
19891997 else :
19901998 tables = ts_fixture .dump_tables ()
1991- sites = tables .sites .copy ()
1992- tables .sites .clear ()
1993-
1994- ancestral_state_map = {
1995- "all-0" : lambda i , site : "" ,
1996- "all-2" : lambda i , site : chr (ord ("A" ) + (i % 26 )) * 2 ,
1997- "mixed" : lambda i , site : chr (ord ("A" ) + (i % 26 )) * (i % 20 ),
1998- "very_long" : lambda i , site : "A" * 100_000_000 if i == 1 else "A" ,
1999- "unicode" : lambda i , site : "💩" * (i + 1 ),
1999+
2000+ str_map = {
2001+ "all-0" : lambda i , item : "" ,
2002+ "all-2" : lambda i , item : chr (ord ("A" ) + (i % 26 )) * 2 ,
2003+ "mixed" : lambda i , item : chr (ord ("A" ) + (i % 26 )) * (i % 20 ),
2004+ "very_long" : lambda i , item : "A" * 100_000_000 if i == 1 else "T" ,
2005+ "unicode" : lambda i , item : "🧬" * (i + 1 ),
20002006 }
20012007
2002- get_ancestral_state = ancestral_state_map [site_lengths ]
2003- for i , site in enumerate (sites ):
2004- tables .sites .append (
2005- site .replace (ancestral_state = get_ancestral_state (i , site ))
2006- )
2008+ if string_array == "sites_ancestral_state" :
2009+ sites = tables .sites .copy ()
2010+ tables .sites .clear ()
2011+ get_ancestral_state = str_map [str_lengths ]
2012+ for i , site in enumerate (sites ):
2013+ tables .sites .append (
2014+ site .replace (ancestral_state = get_ancestral_state (i , site ))
2015+ )
2016+ elif string_array == "mutations_derived_state" :
2017+ mutations = tables .mutations .copy ()
2018+ tables .mutations .clear ()
2019+ get_derived_state = str_map [str_lengths ]
2020+ for i , mutation in enumerate (mutations ):
2021+ tables .mutations .append (
2022+ mutation .replace (
2023+ derived_state = get_derived_state (i , mutation )
2024+ )
2025+ )
2026+
20072027 ts = tables .tree_sequence ()
20082028 ll_ts = ts .ll_tree_sequence
20092029
2010- a = ll_ts .sites_ancestral_state
2030+ a = getattr (ll_ts , string_array )
2031+
20112032 # Contents
2012- if site_lengths == "none" :
2033+ if str_lengths == "none" :
20132034 assert a .size == 0
20142035 else :
2015- for site in ts .sites ():
2016- assert a [site .id ] == site .ancestral_state
2036+ if string_array == "sites_ancestral_state" :
2037+ for site in ts .sites ():
2038+ assert a [site .id ] == site .ancestral_state
2039+ elif string_array == "mutations_derived_state" :
2040+ for mutation in ts .mutations ():
2041+ assert a [mutation .id ] == mutation .derived_state
20172042
20182043 # Read only
20192044 with pytest .raises (AttributeError , match = "not writable" ):
2020- ll_ts . sites_ancestral_state = None
2045+ setattr ( ll_ts , string_array , None )
20212046 with pytest .raises (AttributeError , match = "not writable" ):
2022- del ll_ts . sites_ancestral_state
2047+ delattr ( ll_ts , string_array )
20232048
20242049 with pytest .raises (ValueError , match = "assignment destination" ):
20252050 a [:] = 0
20262051 with pytest .raises (ValueError , match = "assignment destination" ):
20272052 a [0 ] = 0
2028- if site_lengths in [("all-1" ,)]:
2029- with pytest .raises (ValueError , match = "cannot set WRITEABLE" ):
2030- a .setflags (write = True )
20312053
20322054 # Properties
20332055 assert a .dtype == np .dtypes .StringDType ()
20342056 assert a .flags .aligned
20352057 assert a .flags .c_contiguous
2036- b = ll_ts . sites_ancestral_state
2058+ b = getattr ( ll_ts , string_array )
20372059 assert a is not b
20382060 assert np .all (a == b )
20392061
20402062 # Lifetime
2041- a1 = ll_ts . sites_ancestral_state
2063+ a1 = getattr ( ll_ts , string_array )
20422064 a2 = a1 .copy ()
20432065 assert a1 is not a2
20442066 del ll_ts
0 commit comments