@@ -5584,3 +5584,101 @@ def test_mixed_sample_status(self):
55845584 expected = np .array ([[0 , 1 ]])
55855585 assert result .shape == (1 , 2 )
55865586 assert_array_equal (result , expected )
5587+
5588+
class TestSampleNodesByPloidy:
    """
    Tests for ``TreeSequence.sample_nodes_by_ploidy``.

    As exercised here, the method returns a 2D array with one row per group
    of ``ploidy`` sample nodes, filled with sample node IDs in increasing
    order, and raises ``ValueError`` when the grouping is not possible.
    """

    @pytest.mark.parametrize(
        "n_samples,ploidy,expected",
        [
            (6, 2, np.array([[0, 1], [2, 3], [4, 5]])),  # Basic diploid
            (9, 3, np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])),  # Triploid
            (5, 1, np.array([[0], [1], [2], [3], [4]])),  # Ploidy of 1
            (4, 4, np.array([[0, 1, 2, 3]])),  # Ploidy equals number of samples
        ],
    )
    def test_various_ploidy_scenarios(self, n_samples, ploidy, expected):
        """
        Test that an all-sample node table is split into consecutive rows
        of ``ploidy`` node IDs.
        """
        tables = tskit.TableCollection(sequence_length=100)
        # Add sample nodes
        for _ in range(n_samples):
            tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
        ts = tables.tree_sequence()

        result = ts.sample_nodes_by_ploidy(ploidy)
        expected_shape = (n_samples // ploidy, ploidy)
        assert result.shape == expected_shape
        assert_array_equal(result, expected)

    def test_mixed_sample_status(self):
        """
        Test that non-sample nodes interleaved with samples are skipped.
        """
        tables = tskit.TableCollection(sequence_length=100)
        # Add 4 sample nodes (IDs 0, 2, 4, 5) and 2 non-sample nodes (IDs 1, 3)
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
        tables.nodes.add_row(flags=0, time=0)
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
        tables.nodes.add_row(flags=0, time=0)
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
        ts = tables.tree_sequence()

        result = ts.sample_nodes_by_ploidy(2)
        assert result.shape == (2, 2)
        # Should only include nodes with sample flag
        expected = np.array([[0, 2], [4, 5]])
        assert_array_equal(result, expected)

    def test_no_sample_nodes(self):
        """
        Test that a tree sequence without any sample nodes is rejected.
        """
        tables = tskit.TableCollection(sequence_length=100)
        tables.nodes.add_row(flags=0, time=0)
        tables.nodes.add_row(flags=0, time=0)
        ts = tables.tree_sequence()

        with pytest.raises(ValueError, match="No sample nodes in tree sequence"):
            ts.sample_nodes_by_ploidy(2)

    def test_not_multiple_of_ploidy(self):
        """
        Test that a sample count not divisible by the ploidy is rejected.
        """
        tables = tskit.TableCollection(sequence_length=100)
        # Add 5 sample nodes, which cannot be split into pairs
        for _ in range(5):
            tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
        ts = tables.tree_sequence()

        with pytest.raises(ValueError, match="not a multiple of ploidy"):
            ts.sample_nodes_by_ploidy(2)

    def test_with_existing_individuals(self):
        """
        Test that the grouping follows sample node order rather than the
        individual references stored in the node table.
        """
        tables = tskit.TableCollection(sequence_length=100)
        tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
        tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
        # Add nodes with individual references but in a different order
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)

        ts = tables.tree_sequence()
        result = ts.sample_nodes_by_ploidy(2)
        expected = np.array([[0, 1], [2, 3]])
        assert_array_equal(result, expected)
        # The node-to-individual mapping interleaves the nodes, so it must
        # differ from the consecutive grouping asserted above.
        ind_nodes = ts.individuals_nodes
        assert not np.array_equal(result, ind_nodes)

    def test_different_node_flags(self):
        """
        Test that only nodes with the NODE_IS_SAMPLE flag are considered
        samples, regardless of other flags that might be set.
        """
        tables = tskit.TableCollection(sequence_length=100)

        # An arbitrary extra flag bit, distinct from the sample flag used above
        other_flag = 1 << 1

        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
        tables.nodes.add_row(flags=other_flag, time=0)
        tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE | other_flag, time=0)
        # Explicitly a non-sample node with no flags set
        tables.nodes.add_row(flags=0, time=0)

        ts = tables.tree_sequence()

        # We should have 2 sample nodes (0 and 2); nodes 1 and 3 lack the
        # sample flag.
        result = ts.sample_nodes_by_ploidy(2)
        assert result.shape == (1, 2)
        assert_array_equal(result, np.array([[0, 2]]))
0 commit comments