@@ -7074,6 +7074,182 @@ def test_failure_with_migrations(self):
70747074 ts .trim ()
70757075
70767076
7077+ class TestShift :
7078+ """
7079+ Test the shift functionality
7080+ """
7081+
7082+ @pytest .mark .parametrize ("shift" , [- 0.5 , 0 , 0.5 ])
7083+ def test_shift (self , shift ):
7084+ ts = tskit .Tree .generate_comb (2 , span = 2 ).tree_sequence
7085+ tables = ts .dump_tables ()
7086+ tables .delete_intervals ([[0 , 1 ]], simplify = False )
7087+ tables .sites .add_row (1.5 , "A" )
7088+ ts = tables .tree_sequence ()
7089+ ts = ts .shift (shift )
7090+ assert ts .sequence_length == 2 + shift
7091+ assert np .min (ts .tables .edges .left ) == 1 + shift
7092+ assert np .max (ts .tables .edges .right ) == 2 + shift
7093+ assert np .all (ts .tables .sites .position == 1.5 + shift )
7094+ assert len (list (ts .trees ())) == ts .num_trees
7095+
7096+ def test_sequence_length (self ):
7097+ ts = tskit .Tree .generate_comb (2 ).tree_sequence
7098+ ts = ts .shift (1 , sequence_length = 3 )
7099+ assert ts .sequence_length == 3
7100+ ts = ts .shift (- 1 , sequence_length = 1 )
7101+ assert ts .sequence_length == 1
7102+
7103+ def test_empty (self ):
7104+ empty_ts = tskit .TableCollection (1.0 ).tree_sequence ()
7105+ empty_ts = empty_ts .shift (1 )
7106+ assert empty_ts .sequence_length == 2
7107+ empty_ts = empty_ts .shift (- 1.5 )
7108+ assert empty_ts .sequence_length == 0.5
7109+ assert empty_ts .num_nodes == 0
7110+
7111+ def test_migrations (self ):
7112+ tables = tskit .Tree .generate_comb (2 , span = 2 ).tree_sequence .dump_tables ()
7113+ tables .populations .add_row ()
7114+ tables .migrations .add_row (0 , 1 , 0 , 0 , 0 , 0 )
7115+ ts = tables .tree_sequence ().shift (10 )
7116+ assert np .all (ts .tables .migrations .left == 10 )
7117+ assert np .all (ts .tables .migrations .right == 11 )
7118+
7119+ def test_provenance (self ):
7120+ ts = tskit .Tree .generate_comb (2 ).tree_sequence
7121+ ts = ts .shift (1 , record_provenance = False )
7122+ params = json .loads (ts .provenance (- 1 ).record )["parameters" ]
7123+ assert params ["command" ] != "shift"
7124+ ts = ts .shift (1 , sequence_length = 9 )
7125+ params = json .loads (ts .provenance (- 1 ).record )["parameters" ]
7126+ assert params ["command" ] == "shift"
7127+ assert params ["value" ] == 1
7128+ assert params ["sequence_length" ] == 9
7129+
7130+ def test_too_negative (self ):
7131+ ts = tskit .Tree .generate_comb (2 ).tree_sequence
7132+ with pytest .raises (tskit .LibraryError , match = "TSK_ERR_BAD_SEQUENCE_LENGTH" ):
7133+ ts .shift (- 1 )
7134+
7135+ def test_bad_seq_len (self ):
7136+ ts = tskit .Tree .generate_comb (2 ).tree_sequence
7137+ with pytest .raises (
7138+ tskit .LibraryError , match = "TSK_ERR_RIGHT_GREATER_SEQ_LENGTH"
7139+ ):
7140+ ts .shift (1 , sequence_length = 1 )
7141+
7142+
7143+ class TestConcatenate :
7144+ def test_simple (self ):
7145+ ts1 = tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence
7146+ ts2 = tskit .Tree .generate_balanced (5 , arity = 3 , span = 3 ).tree_sequence
7147+ assert ts1 .num_samples == ts2 .num_samples
7148+ assert ts1 .num_nodes != ts2 .num_nodes
7149+ joint_ts = ts1 .concatenate (ts2 )
7150+ assert joint_ts .num_nodes == ts1 .num_nodes + ts2 .num_nodes - 5
7151+ assert joint_ts .sequence_length == ts1 .sequence_length + ts2 .sequence_length
7152+ assert joint_ts .num_samples == ts1 .num_samples
7153+ ts3 = joint_ts .delete_intervals ([[2 , 5 ]]).rtrim ()
7154+ # Have to simplify here, to remove the redundant nodes
7155+ assert ts3 .equals (ts1 .simplify (), ignore_provenance = True )
7156+ ts4 = joint_ts .delete_intervals ([[0 , 2 ]]).ltrim ()
7157+ assert ts4 .equals (ts2 .simplify (), ignore_provenance = True )
7158+
7159+ def test_multiple (self ):
7160+ np .random .seed (42 )
7161+ ts3 = [
7162+ tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence ,
7163+ tskit .Tree .generate_balanced (5 , arity = 3 , span = 3 ).tree_sequence ,
7164+ tskit .Tree .generate_star (5 , span = 5 ).tree_sequence ,
7165+ ]
7166+ for i in range (1 , len (ts3 )):
7167+ # shuffle the sample nodes so they don't have the same IDs
7168+ ts3 [i ] = ts3 [i ].subset (np .random .permutation (ts3 [i ].num_nodes ))
7169+ assert not np .all (ts3 [0 ].samples () == ts3 [1 ].samples ())
7170+ assert not np .all (ts3 [0 ].samples () == ts3 [2 ].samples ())
7171+ assert not np .all (ts3 [1 ].samples () == ts3 [2 ].samples ())
7172+ ts = ts3 [0 ].concatenate (* ts3 [1 :])
7173+ assert ts .sequence_length == sum ([t .sequence_length for t in ts3 ])
7174+ assert ts .num_nodes - ts .num_samples == sum (
7175+ [t .num_nodes - t .num_samples for t in ts3 ]
7176+ )
7177+ assert np .all (ts .samples () == ts3 [0 ].samples ())
7178+
7179+ def test_empty (self ):
7180+ empty_ts = tskit .TableCollection (10 ).tree_sequence ()
7181+ ts = empty_ts .concatenate (empty_ts , empty_ts , empty_ts )
7182+ assert ts .num_nodes == 0
7183+ assert ts .sequence_length == 40
7184+
7185+ def test_samples_at_end (self ):
7186+ ts1 = tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence
7187+ ts2 = tskit .Tree .generate_balanced (5 , arity = 3 , span = 3 ).tree_sequence
7188+ # reverse the node order
7189+ ts1 = ts1 .subset (np .arange (ts1 .num_nodes )[::- 1 ])
7190+ assert ts1 .num_samples == ts2 .num_samples
7191+ assert np .all (ts1 .samples () != ts2 .samples ())
7192+ joint_ts = ts1 .concatenate (ts2 )
7193+ assert joint_ts .num_samples == ts1 .num_samples
7194+ assert np .all (joint_ts .samples () == ts1 .samples ())
7195+
7196+ def test_internal_samples (self ):
7197+ tables = tskit .Tree .generate_comb (4 , span = 2 ).tree_sequence .dump_tables ()
7198+ nodes_flags = tables .nodes .flags
7199+ nodes_flags [:] = tskit .NODE_IS_SAMPLE
7200+ nodes_flags [- 1 ] = 0 # Only root is not a sample
7201+ tables .nodes .flags = nodes_flags
7202+ ts = tables .tree_sequence ()
7203+ joint_ts = ts .concatenate (ts )
7204+ assert joint_ts .num_samples == ts .num_samples
7205+ assert joint_ts .num_nodes == ts .num_nodes + 1
7206+ assert joint_ts .sequence_length == ts .sequence_length * 2
7207+
7208+ def test_some_shared_samples (self ):
7209+ ts1 = tskit .Tree .generate_comb (4 , span = 2 ).tree_sequence
7210+ ts2 = tskit .Tree .generate_balanced (8 , arity = 3 , span = 3 ).tree_sequence
7211+ shared = np .full (ts2 .num_nodes , tskit .NULL )
7212+ shared [0 ] = 1
7213+ shared [1 ] = 0
7214+ joint_ts = ts1 .concatenate (ts2 , node_mappings = [shared ])
7215+ assert joint_ts .sequence_length == ts1 .sequence_length + ts2 .sequence_length
7216+ assert joint_ts .num_samples == ts1 .num_samples + ts2 .num_samples - 2
7217+ assert joint_ts .num_nodes == ts1 .num_nodes + ts2 .num_nodes - 2
7218+
7219+ def test_provenance (self ):
7220+ ts = tskit .Tree .generate_comb (2 ).tree_sequence
7221+ ts = ts .concatenate (ts , record_provenance = False )
7222+ params = json .loads (ts .provenance (- 1 ).record )["parameters" ]
7223+ assert params ["command" ] != "concatenate"
7224+
7225+ ts = ts .concatenate (ts )
7226+ params = json .loads (ts .provenance (- 1 ).record )["parameters" ]
7227+ assert params ["command" ] == "concatenate"
7228+
7229+ def test_unequal_samples (self ):
7230+ ts1 = tskit .Tree .generate_comb (5 , span = 2 ).tree_sequence
7231+ ts2 = tskit .Tree .generate_balanced (4 , arity = 3 , span = 3 ).tree_sequence
7232+ with pytest .raises (ValueError , match = "must have the same number of samples" ):
7233+ ts1 .concatenate (ts2 )
7234+
7235+ @pytest .mark .skip (
7236+ reason = "union bug: https://github.com/tskit-dev/tskit/issues/3168"
7237+ )
7238+ def test_duplicate_ts (self ):
7239+ ts1 = tskit .Tree .generate_comb (3 , span = 4 ).tree_sequence
7240+ ts = ts1 .keep_intervals ([[0 , 1 ]]).trim () # a quarter of the original
7241+ nm = np .arange (ts .num_nodes ) # all nodes identical
7242+ ts2 = ts .concatenate (ts , ts , ts , node_mappings = [nm ] * 3 , add_populations = False )
7243+ ts2 = ts2 .simplify () # squash the edges
7244+ assert ts1 .equals (ts2 , ignore_provenance = True )
7245+
7246+ def test_node_mappings_bad_len (self ):
7247+ ts = tskit .Tree .generate_comb (3 , span = 2 ).tree_sequence
7248+ nm = np .arange (ts .num_nodes )
7249+ with pytest .raises (ValueError , match = "same number of node_mappings" ):
7250+ ts .concatenate (ts , ts , ts , node_mappings = [nm , nm ])
7251+
7252+
70777253class TestMissingData :
70787254 """
70797255 Test various aspects of missing data functionality
0 commit comments