Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

**Features**

- Add ``TreeSequence.sample_nodes_by_ploidy`` method to return the sample nodes
in a tree sequence, grouped by a ploidy value.
(:user:`benjeffery`, :pr:`3157`)

- Add ``TreeSequence.individuals_nodes`` attribute to return the nodes
associated with each individual as a numpy array.
(:user:`benjeffery`, :pr:`3153`)
Expand Down
84 changes: 84 additions & 0 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5584,3 +5584,87 @@ def test_mixed_sample_status(self):
expected = np.array([[0, 1]])
assert result.shape == (1, 2)
assert_array_equal(result, expected)


class TestSampleNodesByPloidy:
@pytest.mark.parametrize(
"n_samples,ploidy,expected",
[
(6, 2, np.array([[0, 1], [2, 3], [4, 5]])), # Basic diploid
(9, 3, np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])), # Triploid
(5, 1, np.array([[0], [1], [2], [3], [4]])), # Ploidy of 1
(4, 4, np.array([[0, 1, 2, 3]])), # Ploidy equals number of samples
],
)
def test_various_ploidy_scenarios(self, n_samples, ploidy, expected):
tables = tskit.TableCollection(sequence_length=100)
for _ in range(n_samples):
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
ts = tables.tree_sequence()

result = ts.sample_nodes_by_ploidy(ploidy)
expected_shape = (n_samples // ploidy, ploidy)
assert result.shape == expected_shape
assert_array_equal(result, expected)

def test_mixed_sample_status(self):
tables = tskit.TableCollection(sequence_length=100)
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
tables.nodes.add_row(flags=0, time=0)
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
tables.nodes.add_row(flags=0, time=0)
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
ts = tables.tree_sequence()

result = ts.sample_nodes_by_ploidy(2)
assert result.shape == (2, 2)
expected = np.array([[0, 2], [4, 5]])
assert_array_equal(result, expected)

def test_no_sample_nodes(self):
tables = tskit.TableCollection(sequence_length=100)
tables.nodes.add_row(flags=0, time=0)
tables.nodes.add_row(flags=0, time=0)
ts = tables.tree_sequence()

with pytest.raises(ValueError, match="No sample nodes in tree sequence"):
ts.sample_nodes_by_ploidy(2)

def test_not_multiple_of_ploidy(self):
tables = tskit.TableCollection(sequence_length=100)
for _ in range(5):
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
ts = tables.tree_sequence()

with pytest.raises(ValueError, match="not a multiple of ploidy"):
ts.sample_nodes_by_ploidy(2)

def test_with_existing_individuals(self):
tables = tskit.TableCollection(sequence_length=100)
tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
tables.individuals.add_row(flags=0, location=(0, 0), metadata=b"")
# Add nodes with individual references but in a different order
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=1)
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=0)

ts = tables.tree_sequence()
result = ts.sample_nodes_by_ploidy(2)
expected = np.array([[0, 1], [2, 3]])
assert_array_equal(result, expected)
ind_nodes = ts.individuals_nodes
assert not np.array_equal(result, ind_nodes)

def test_different_node_flags(self):
tables = tskit.TableCollection(sequence_length=100)
OTHER_FLAG1 = 1 << 1
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
tables.nodes.add_row(flags=OTHER_FLAG1, time=0)
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE | OTHER_FLAG1, time=0)
tables.nodes.add_row()
ts = tables.tree_sequence()
result = ts.sample_nodes_by_ploidy(2)
assert result.shape == (1, 2)
assert_array_equal(result, np.array([[0, 2]]))
25 changes: 25 additions & 0 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -10495,6 +10495,31 @@ def ld_matrix(
mode=mode,
)

def sample_nodes_by_ploidy(self, ploidy):
"""
Returns an 2D array of node IDs, where each row has length `ploidy`.
This is useful when individuals are not defined in the tree sequence
so `TreeSequence.individuals_nodes` cannot be used. The samples are
placed in the array in the order which they are found in the node
table. The number of sample nodes must be a multiple of ploidy.

:param int ploidy: The number of samples per individual.
:return: A 2D array of node IDs, where each row has length `ploidy`.
:rtype: numpy.ndarray
"""
sample_node_ids = np.flatnonzero(self.nodes_flags & tskit.NODE_IS_SAMPLE)
num_samples = len(sample_node_ids)
if num_samples == 0:
raise ValueError("No sample nodes in tree sequence")
if num_samples % ploidy != 0:
raise ValueError(
f"Number of sample nodes {num_samples} is not a multiple "
f"of ploidy {ploidy}"
)
num_samples_per_individual = num_samples // ploidy
sample_node_ids = sample_node_ids.reshape((num_samples_per_individual, ploidy))
return sample_node_ids

############################################
#
# Deprecated APIs. These are either already unsupported, or will be unsupported in a
Expand Down
Loading