2727import io
2828import itertools
2929import json
30+ import platform
3031import random
3132import sys
3233import unittest
4344import tskit .provenance as provenance
4445
4546
47+ IS_WINDOWS = platform .system () == "Windows"
48+
49+
4650def simple_keep_intervals (tables , intervals , simplify = True , record_provenance = True ):
4751 """
4852 Simple Python implementation of keep_intervals.
@@ -7141,18 +7145,223 @@ def test_bad_seq_len(self):
71417145 ts .shift (1 , sequence_length = 1 )
71427146
71437147
7148+ class TestMerge :
7149+ def test_empty (self ):
7150+ ts = tskit .TableCollection (2 ).tree_sequence ()
7151+ merged_ts = ts .merge (ts , node_mapping = [])
7152+ assert merged_ts .num_nodes == 0
7153+ assert merged_ts .num_edges == 0
7154+ assert merged_ts .sequence_length == 2
7155+
7156+ def test_overlay (self ):
7157+ ts1 = tskit .Tree .generate_balanced (4 , span = 2 ).tree_sequence
7158+ tables = tskit .Tree .generate_comb (4 , span = 2 ).tree_sequence .dump_tables ()
7159+ tables .populations .add_row ()
7160+ tables .nodes [5 ] = tables .nodes [5 ].replace (
7161+ flags = tskit .NODE_IS_SAMPLE , population = 0
7162+ )
7163+ ts2 = tables .tree_sequence ()
7164+ ts_merge = ts1 .merge (ts2 , node_mapping = np .full (ts1 .num_nodes , tskit .NULL ))
7165+ assert ts_merge .sequence_length == ts1 .sequence_length
7166+ assert ts_merge .num_samples == ts1 .num_samples + ts2 .num_samples
7167+ assert ts_merge .num_nodes == ts1 .num_nodes + ts2 .num_nodes
7168+ assert ts_merge .num_edges == ts1 .num_edges + ts2 .num_edges
7169+ assert ts_merge .num_trees == 1
7170+ assert ts_merge .num_populations == 1
7171+ assert ts_merge .first ().num_roots == 2
7172+
7173+ def test_split_and_merge (self ):
7174+ # Cut up a single tree into alternating edges and mutations, then merge
7175+ ts = tskit .Tree .generate_comb (4 , span = 10 ).tree_sequence
7176+ ts = msprime .sim_mutations (ts , rate = 0.1 , random_seed = 1 )
7177+ mut_counts = np .bincount (ts .mutations_site , minlength = ts .num_sites )
7178+ assert min (mut_counts ) == 1
7179+ assert max (mut_counts ) > 1
7180+ tables1 = ts .dump_tables ()
7181+ tables1 .mutations .clear ()
7182+ tables2 = tables1 .copy ()
7183+ i = 0
7184+ for s in ts .sites ():
7185+ for m in s .mutations :
7186+ i += 1
7187+ if i % 2 :
7188+ tables1 .mutations .append (m .replace (parent = tskit .NULL ))
7189+ else :
7190+ tables2 .mutations .append (m .replace (parent = tskit .NULL ))
7191+ tables1 .simplify ()
7192+ tables2 .simplify ()
7193+ assert tables1 .sites .num_rows != ts .num_sites
7194+ tables1 .edges .clear ()
7195+ tables2 .edges .clear ()
7196+ for e in ts .edges ():
7197+ if e .id % 2 :
7198+ tables1 .edges .append (e )
7199+ else :
7200+ tables2 .edges .append (e )
7201+ ts1 = tables1 .tree_sequence ()
7202+ ts2 = tables2 .tree_sequence ()
7203+ new_ts = ts1 .merge (ts2 , node_mapping = np .arange (ts .num_nodes )).simplify ()
7204+ assert new_ts .equals (ts , ignore_provenance = True )
7205+
7206+ def test_multi_tree (self ):
7207+ ts = msprime .sim_ancestry (
7208+ 2 , sequence_length = 4 , recombination_rate = 1 , random_seed = 1
7209+ )
7210+ ts = msprime .sim_mutations (ts , rate = 1 , random_seed = 1 )
7211+ assert ts .num_trees > 3
7212+ assert ts .num_mutations > 4
7213+ ts1 = ts .keep_intervals ([[0 , 1.5 ]], simplify = False )
7214+ ts2 = ts .keep_intervals ([[1.5 , 4 ]], simplify = False )
7215+ new_ts = ts1 .merge (
7216+ ts2 , node_mapping = np .arange (ts .num_nodes ), add_populations = False
7217+ )
7218+ assert new_ts .num_trees == ts .num_trees + 1
7219+ new_ts = new_ts .simplify ()
7220+ new_ts .equals (ts , ignore_provenance = True )
7221+
7222+ def test_new_individuals (self ):
7223+ ts1 = msprime .sim_ancestry (2 , sequence_length = 1 , random_seed = 1 )
7224+ ts2 = msprime .sim_ancestry (2 , sequence_length = 1 , random_seed = 2 )
7225+ tables = ts2 .dump_tables ()
7226+ tables .edges .clear ()
7227+ ts2 = tables .tree_sequence ()
7228+ node_map = np .full (ts2 .num_nodes , tskit .NULL )
7229+ node_map [0 :2 ] = [0 , 1 ] # map first two nodes to themselves
7230+ ts_merged = ts1 .merge (ts2 , node_mapping = node_map )
7231+ assert ts_merged .num_nodes == ts1 .num_nodes + ts2 .num_nodes - 2
7232+ assert ts1 .num_individuals == 2
7233+ assert ts_merged .num_individuals == 3
7234+
7235+ def test_popcheck (self ):
7236+ tables = tskit .TableCollection (1 )
7237+ p1 = tables .populations .add_row (b"foo" )
7238+ p2 = tables .populations .add_row (b"bar" )
7239+ tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE , population = p1 )
7240+ tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE , population = p2 )
7241+ ts1 = tables .tree_sequence ()
7242+ tables .populations [0 ] = tables .populations [0 ].replace (metadata = b"baz" )
7243+ ts2 = tables .tree_sequence ()
7244+ with pytest .raises (ValueError , match = "Non-matching populations" ):
7245+ ts1 .merge (ts2 , node_mapping = [0 , 1 ])
7246+ ts1 .merge (ts2 , node_mapping = [0 , 1 ], check_populations = False )
7247+ # Check with add_populations=False
7248+ ts1 .merge (ts2 , node_mapping = [- 1 , 1 ]) # only merge the last one
7249+ with pytest .raises (ValueError , match = "Non-matching populations" ):
7250+ ts1 .merge (ts2 , node_mapping = [- 1 , 1 ], add_populations = False )
7251+
7252+ with pytest .raises (ValueError , match = "Non-matching populations" ):
7253+ ts1 .simplify ([0 ]).merge (ts2 , node_mapping = [- 1 , 1 ])
7254+
7255+ def test_isolated_mutations (self ):
7256+ tables = tskit .TableCollection (1 )
7257+ u = tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE )
7258+ s = tables .sites .add_row (0.5 , "A" )
7259+ tables .mutations .add_row (s , u , derived_state = "T" , time = 1 , metadata = b"xxx" )
7260+ ts1 = tables .tree_sequence ()
7261+ tables .mutations [0 ] = tables .mutations [0 ].replace (time = 0.5 , metadata = b"yyy" )
7262+ ts2 = tables .tree_sequence ()
7263+ ts_merge = ts1 .merge (ts2 , node_mapping = [0 ])
7264+ assert ts_merge .num_sites == 1
7265+ assert ts_merge .num_mutations == 2
7266+ assert ts_merge .mutation (0 ).time == 1
7267+ assert ts_merge .mutation (0 ).parent == tskit .NULL
7268+ assert ts_merge .mutation (0 ).metadata == b"xxx"
7269+ assert ts_merge .mutation (1 ).time == 0.5
7270+ assert ts_merge .mutation (1 ).parent == 0
7271+ assert ts_merge .mutation (1 ).metadata == b"yyy"
7272+
7273+ def test_identity (self ):
7274+ tables = tskit .TableCollection (1 )
7275+ tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE )
7276+ ts = tables .tree_sequence ()
7277+ ts_merge = ts .merge (ts , node_mapping = [0 ])
7278+ assert ts .equals (ts_merge , ignore_provenance = True )
7279+
7280+ @pytest .mark .skipif (IS_WINDOWS , reason = "Msprime gives different result on Windows" )
7281+ def test_migrations (self ):
7282+ pop_configs = [msprime .PopulationConfiguration (3 ) for _ in range (2 )]
7283+ migration_matrix = [[0 , 0.001 ], [0.001 , 0 ]]
7284+ ts = msprime .simulate (
7285+ population_configurations = pop_configs ,
7286+ migration_matrix = migration_matrix ,
7287+ record_migrations = True ,
7288+ recombination_rate = 2 ,
7289+ random_seed = 42 , # pick a seed that gives min(migrations.left) > 0
7290+ end_time = 100 ,
7291+ )
7292+ # No migration_table.squash() function exists, so we just try to cut on the
7293+ # LHS of all the migrations
7294+ assert ts .num_migrations > 0
7295+ assert ts .migrations_left .min () > 0
7296+ cutpoint = ts .migrations_left .min ()
7297+ ts1 = ts .keep_intervals ([[0 , cutpoint ]], simplify = False )
7298+ ts2 = ts .keep_intervals ([[cutpoint , ts .sequence_length ]], simplify = False )
7299+ ts_new = ts1 .merge (ts2 , node_mapping = np .arange (ts .num_nodes ))
7300+ tables = ts_new .dump_tables ()
7301+ tables .edges .squash ()
7302+ tables .sort ()
7303+ ts_new = tables .tree_sequence ()
7304+ ts .tables .assert_equals (ts_new .tables , ignore_provenance = True )
7305+
7306+ def test_provenance (self ):
7307+ tables = tskit .TableCollection (1 )
7308+ tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE )
7309+ ts = tables .tree_sequence ()
7310+ ts_merge = ts .merge (ts , node_mapping = [0 ], record_provenance = False )
7311+ assert ts_merge .num_provenances == ts .num_provenances
7312+ ts_merge = ts .merge (ts , node_mapping = [0 ])
7313+ assert ts_merge .num_provenances == ts .num_provenances + 1
7314+ prov = json .loads (ts_merge .provenance (- 1 ).record )
7315+ assert prov ["parameters" ]["command" ] == "merge"
7316+ assert prov ["parameters" ]["node_mapping" ] == [0 ]
7317+ assert prov ["parameters" ]["add_populations" ] is True
7318+ assert prov ["parameters" ]["check_populations" ] is True
7319+
7320+ def test_bad_sequence_length (self ):
7321+ ts1 = tskit .TableCollection (1 ).tree_sequence ()
7322+ ts2 = tskit .TableCollection (2 ).tree_sequence ()
7323+ with pytest .raises (ValueError , match = "sequence length" ):
7324+ ts1 .merge (ts2 , node_mapping = [])
7325+
7326+ def test_bad_node_mapping (self ):
7327+ ts = tskit .Tree .generate_comb (5 ).tree_sequence
7328+ with pytest .raises (ValueError , match = "node_mapping" ):
7329+ ts .merge (ts , node_mapping = [0 , 1 , 2 ])
7330+
7331+ def test_bad_populations (self ):
7332+ tables = tskit .TableCollection (1 )
7333+ tables = tskit .TableCollection (1 )
7334+ p1 = tables .populations .add_row ()
7335+ p2 = tables .populations .add_row ()
7336+ tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE , population = p1 )
7337+ tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE , population = p1 )
7338+ tables .nodes .add_row (time = 0 , flags = tskit .NODE_IS_SAMPLE , population = p2 )
7339+ ts2 = tables .tree_sequence ()
7340+ ts1 = ts2 .simplify ([0 , 1 ])
7341+ assert ts1 .num_populations == 1
7342+ assert ts2 .num_populations == 2
7343+ ts2 .merge (ts1 , [0 , - 1 ], check_populations = False , add_populations = False )
7344+ with pytest .raises (ValueError , match = "population not present" ):
7345+ ts1 .merge (ts2 , [0 , - 1 , - 1 ], check_populations = False , add_populations = False )
7346+
7347+
71447348class TestConcatenate :
71457349 def test_simple (self ):
71467350 ts1 = tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence
7351+ ts1 = msprime .sim_mutations (ts1 , rate = 1 , random_seed = 1 )
71477352 ts2 = tskit .Tree .generate_balanced (5 , arity = 3 , span = 3 ).tree_sequence
7353+ ts2 = msprime .sim_mutations (ts2 , rate = 1 , random_seed = 1 )
71487354 assert ts1 .num_samples == ts2 .num_samples
71497355 assert ts1 .num_nodes != ts2 .num_nodes
71507356 joint_ts = ts1 .concatenate (ts2 )
71517357 assert joint_ts .num_nodes == ts1 .num_nodes + ts2 .num_nodes - 5
71527358 assert joint_ts .sequence_length == ts1 .sequence_length + ts2 .sequence_length
71537359 assert joint_ts .num_samples == ts1 .num_samples
7360+ assert joint_ts .num_sites == ts1 .num_sites + ts2 .num_sites
7361+ assert joint_ts .num_mutations == ts1 .num_mutations + ts2 .num_mutations
71547362 ts3 = joint_ts .delete_intervals ([[2 , 5 ]]).rtrim ()
71557363 # Have to simplify here, to remove the redundant nodes
7364+ ts3 .tables .assert_equals (ts1 .tables , ignore_provenance = True )
71567365 assert ts3 .equals (ts1 .simplify (), ignore_provenance = True )
71577366 ts4 = joint_ts .delete_intervals ([[0 , 2 ]]).ltrim ()
71587367 assert ts4 .equals (ts2 .simplify (), ignore_provenance = True )
@@ -7183,6 +7392,13 @@ def test_empty(self):
71837392 assert ts .num_nodes == 0
71847393 assert ts .sequence_length == 40
71857394
7395+ def test_check_populations (self ):
7396+ ts = msprime .sim_ancestry (2 )
7397+ ts1 = ts .concatenate (ts , ts , check_populations = True )
7398+ assert ts1 .num_populations == 1
7399+ ts2 = ts .concatenate (ts , ts , add_populations = True , check_populations = True )
7400+ assert ts2 .num_populations == 3
7401+
71867402 def test_samples_at_end (self ):
71877403 ts1 = tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence
71887404 ts2 = tskit .Tree .generate_balanced (5 , arity = 3 , span = 3 ).tree_sequence
@@ -7200,22 +7416,58 @@ def test_internal_samples(self):
72007416 nodes_flags [:] = tskit .NODE_IS_SAMPLE
72017417 nodes_flags [- 1 ] = 0 # Only root is not a sample
72027418 tables .nodes .flags = nodes_flags
7203- ts = tables .tree_sequence ()
7419+ ts = msprime .sim_mutations (tables .tree_sequence (), rate = 0.5 , random_seed = 1 )
7420+ assert ts .num_mutations > 0
7421+ assert ts .num_mutations > ts .num_sites
72047422 joint_ts = ts .concatenate (ts )
72057423 assert joint_ts .num_samples == ts .num_samples
72067424 assert joint_ts .num_nodes == ts .num_nodes + 1
7425+ assert joint_ts .num_mutations == ts .num_mutations * 2
7426+ assert joint_ts .num_sites == ts .num_sites * 2
72077427 assert joint_ts .sequence_length == ts .sequence_length * 2
72087428
72097429 def test_some_shared_samples (self ):
7210- ts1 = tskit .Tree .generate_comb (4 , span = 2 ).tree_sequence
7211- ts2 = tskit .Tree .generate_balanced (8 , arity = 3 , span = 3 ).tree_sequence
7212- shared = np .full (ts2 .num_nodes , tskit .NULL )
7213- shared [0 ] = 1
7214- shared [1 ] = 0
7215- joint_ts = ts1 .concatenate (ts2 , node_mappings = [shared ])
7216- assert joint_ts .sequence_length == ts1 .sequence_length + ts2 .sequence_length
7217- assert joint_ts .num_samples == ts1 .num_samples + ts2 .num_samples - 2
7218- assert joint_ts .num_nodes == ts1 .num_nodes + ts2 .num_nodes - 2
7430+ tables = tskit .Tree .generate_comb (5 ).tree_sequence .dump_tables ()
7431+ tables .nodes [5 ] = tables .nodes [5 ].replace (flags = tskit .NODE_IS_SAMPLE )
7432+ ts1 = tables .tree_sequence ()
7433+ tables = tskit .Tree .generate_balanced (5 ).tree_sequence .dump_tables ()
7434+ tables .nodes [5 ] = tables .nodes [5 ].replace (flags = tskit .NODE_IS_SAMPLE )
7435+ ts2 = tables .tree_sequence ()
7436+ assert ts1 .num_samples == ts2 .num_samples
7437+ joint_ts = ts1 .concatenate (ts2 )
7438+ assert joint_ts .num_samples == ts1 .num_samples
7439+ assert joint_ts .num_edges == ts1 .num_edges + ts2 .num_edges
7440+ for tree in joint_ts .trees ():
7441+ assert tree .num_roots == 1
7442+
7443+ @pytest .mark .parametrize ("simplify" , [True , False ])
7444+ def test_wf_sim (self , simplify ):
7445+ # Test that we can split & concat a wf_sim ts, which has internal samples
7446+ tables = wf .wf_sim (
7447+ 6 ,
7448+ 5 ,
7449+ seed = 3 ,
7450+ deep_history = True ,
7451+ initial_generation_samples = True ,
7452+ num_loci = 10 ,
7453+ )
7454+ tables .sort ()
7455+ tables .simplify ()
7456+ ts = msprime .mutate (tables .tree_sequence (), rate = 0.05 , random_seed = 234 )
7457+ assert ts .num_trees > 2
7458+ assert len (np .unique (ts .nodes_time [ts .samples ()])) > 1
7459+ ts1 = ts .keep_intervals ([[0 , 4.5 ]], simplify = False ).trim ()
7460+ ts2 = ts .keep_intervals ([[4.5 , ts .sequence_length ]], simplify = False ).trim ()
7461+ if simplify :
7462+ ts1 = ts1 .simplify (filter_nodes = False )
7463+ ts2 , node_map = ts2 .simplify (map_nodes = True )
7464+ node_mapping = np .zeros_like (node_map , shape = ts2 .num_nodes )
7465+ kept = node_map != tskit .NULL
7466+ node_mapping [node_map [kept ]] = np .arange (len (node_map ))[kept ]
7467+ else :
7468+ node_mapping = np .arange (ts .num_nodes )
7469+ ts_new = ts1 .concatenate (ts2 , node_mappings = [node_mapping ]).simplify ()
7470+ ts_new .tables .assert_equals (ts .tables , ignore_provenance = True )
72197471
72207472 def test_provenance (self ):
72217473 ts = tskit .Tree .generate_comb (2 ).tree_sequence
@@ -7233,9 +7485,12 @@ def test_unequal_samples(self):
72337485 with pytest .raises (ValueError , match = "must have the same number of samples" ):
72347486 ts1 .concatenate (ts2 )
72357487
7236- @pytest .mark .skip (
7237- reason = "union bug: https://github.com/tskit-dev/tskit/issues/3168"
7238- )
7488+ def test_different_sample_numbers (self ):
7489+ ts1 = tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence
7490+ ts2 = tskit .Tree .generate_balanced (4 , arity = 3 , span = 3 ).tree_sequence
7491+ with pytest .raises (ValueError , match = "must have the same number of samples" ):
7492+ ts1 .concatenate (ts2 )
7493+
72397494 def test_duplicate_ts (self ):
72407495 ts1 = tskit .Tree .generate_comb (3 , span = 4 ).tree_sequence
72417496 ts = ts1 .keep_intervals ([[0 , 1 ]]).trim () # a quarter of the original
0 commit comments