diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 7f1b29ce8d..a54894f4e1 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,6 +4,10 @@ **Features** +- Add ``TreeSequence.sample_nodes_by_ploidy`` method to return the sample nodes + in a tree sequence, grouped by a ploidy value. + (:user:`benjeffery`, :pr:`3157`) + - Add ``TreeSequence.individuals_nodes`` attribute to return the nodes associated with each individual as a numpy array. (:user:`benjeffery`, :pr:`3153`) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index f703408013..856987b383 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -5584,3 +5584,87 @@ def test_mixed_sample_status(self): expected = np.array([[0, 1]]) assert result.shape == (1, 2) assert_array_equal(result, expected) + + +class TestSampleNodesByPloidy: + @pytest.mark.parametrize( + "n_samples,ploidy,expected", + [ + (6, 2, np.array([[0, 1], [2, 3], [4, 5]])), # Basic diploid + (9, 3, np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])), # Triploid + (5, 1, np.array([[0], [1], [2], [3], [4]])), # Ploidy of 1 + (4, 4, np.array([[0, 1, 2, 3]])), # Ploidy equals number of samples + ], + ) + def test_various_ploidy_scenarios(self, n_samples, ploidy, expected): + tables = tskit.TableCollection(sequence_length=100) + for _ in range(n_samples): + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + ts = tables.tree_sequence() + + result = ts.sample_nodes_by_ploidy(ploidy) + expected_shape = (n_samples // ploidy, ploidy) + assert result.shape == expected_shape + assert_array_equal(result, expected) + + def test_mixed_sample_status(self): + tables = tskit.TableCollection(sequence_length=100) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.nodes.add_row(flags=0, time=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.nodes.add_row(flags=0, time=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + ts = tables.tree_sequence() + + result = ts.sample_nodes_by_ploidy(2) + assert result.shape == (2, 2) + expected = np.array([[0, 2], [4, 5]]) + assert_array_equal(result, expected) + + def test_no_sample_nodes(self): + tables = tskit.TableCollection(sequence_length=100) + tables.nodes.add_row(flags=0, time=0) + tables.nodes.add_row(flags=0, time=0) + ts = tables.tree_sequence() + + with pytest.raises(ValueError, match="No sample nodes in tree sequence"): + ts.sample_nodes_by_ploidy(2) + + def test_not_multiple_of_ploidy(self): + tables = tskit.TableCollection(sequence_length=100) + for _ in range(5): + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + ts = tables.tree_sequence() + + with pytest.raises(ValueError, match="not a multiple of ploidy"): + ts.sample_nodes_by_ploidy(2) + + def test_with_existing_individuals(self): + tables = tskit.TableCollection(sequence_length=100) + tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"") + tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"") + # Add nodes with individual references but in a different order + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0) + + ts = tables.tree_sequence() + result = ts.sample_nodes_by_ploidy(2) + expected = np.array([[0, 1], [2, 3]]) + assert_array_equal(result, expected) + ind_nodes = ts.individuals_nodes + assert not np.array_equal(result, ind_nodes) + + def test_different_node_flags(self): + tables = tskit.TableCollection(sequence_length=100) + OTHER_FLAG1 = 1 << 1 + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.nodes.add_row(flags=OTHER_FLAG1, time=0) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE | OTHER_FLAG1, time=0) + tables.nodes.add_row() + ts = tables.tree_sequence() + result = ts.sample_nodes_by_ploidy(2) + assert result.shape == (1, 2) + assert_array_equal(result, np.array([[0, 2]])) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 16838c332e..3b1a410cb0 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10495,6 +10495,31 @@ def ld_matrix( mode=mode, ) + def sample_nodes_by_ploidy(self, ploidy): + """ + Returns an 2D array of node IDs, where each row has length `ploidy`. + This is useful when individuals are not defined in the tree sequence + so `TreeSequence.individuals_nodes` cannot be used. The samples are + placed in the array in the order which they are found in the node + table. The number of sample nodes must be a multiple of ploidy. + + :param int ploidy: The number of samples per individual. + :return: A 2D array of node IDs, where each row has length `ploidy`. + :rtype: numpy.ndarray + """ + sample_node_ids = np.flatnonzero(self.nodes_flags & tskit.NODE_IS_SAMPLE) + num_samples = len(sample_node_ids) + if num_samples == 0: + raise ValueError("No sample nodes in tree sequence") + if num_samples % ploidy != 0: + raise ValueError( + f"Number of sample nodes {num_samples} is not a multiple " + f"of ploidy {ploidy}" + ) + num_samples_per_individual = num_samples // ploidy + sample_node_ids = sample_node_ids.reshape((num_samples_per_individual, ploidy)) + return sample_node_ids + ############################################ # # Deprecated APIs. These are either already unsupported, or will be unsupported in a