From 942e3838bdae867feacca550ad22538c8ca5d39b Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Thu, 15 May 2025 15:03:57 +0100 Subject: [PATCH 1/3] Use plain `python` in doc Makefile --- docs/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Makefile b/docs/Makefile index 53e2cf89ef..7068c5f766 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -2,7 +2,7 @@ # Need to set PYTHONPATH so that we pick up the local tskit PYPATH=$(shell pwd)/../python/ TSK_VERSION:=$(shell PYTHONPATH=${PYPATH} \ - python3 -c 'import tskit; print(tskit.__version__.split("+")[0])') + python -c 'import tskit; print(tskit.__version__.split("+")[0])') BUILDDIR = _build DOXYGEN_XML=doxygen/xml From 589a037ddd7a368ebac1ce530ae314bc64192132 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Fri, 16 May 2025 02:21:41 +0100 Subject: [PATCH 2/3] Add a link to the subset docs from `union` --- python/tskit/trees.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 3b1a410cb0..e0100b7e7a 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -7304,9 +7304,9 @@ def union( ``other`` to nodes in ``self``. :param bool check_shared_equality: If True, the shared portions of the tree sequences will be checked for equality. It does so by - subsetting both ``self`` and ``other`` on the equivalent nodes - specified in ``node_mapping``, and then checking for equality of - the subsets. + running :meth:`TreeSequence.subset` on both ``self`` and ``other`` + for the equivalent nodes specified in ``node_mapping``, and then + checking for equality of the subsets. :param bool add_populations: If True, nodes new to ``self`` will be assigned new population IDs. 
:param bool record_provenance: Whether to record a provenance entry From a39e503f41a69dfde70718a4a81b72ff563b5e1a Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Thu, 15 May 2025 15:11:18 +0100 Subject: [PATCH 3/3] Add `shift` and `concatenate` functions Fixes #3164 --- docs/python-api.md | 3 + python/CHANGELOG.rst | 5 + python/tests/test_topology.py | 176 ++++++++++++++++++++++++++++++++++ python/tskit/tables.py | 34 +++++++ python/tskit/trees.py | 100 +++++++++++++++++++ 5 files changed, 318 insertions(+) diff --git a/docs/python-api.md b/docs/python-api.md index 386db89486..1ccf28d0ff 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -262,10 +262,12 @@ which perform the same actions but modify the {class}`TableCollection` in place. TreeSequence.simplify TreeSequence.subset TreeSequence.union + TreeSequence.concatenate TreeSequence.keep_intervals TreeSequence.delete_intervals TreeSequence.delete_sites TreeSequence.trim + TreeSequence.shift TreeSequence.split_edges TreeSequence.decapitate TreeSequence.extend_haplotypes @@ -750,6 +752,7 @@ a functional way, returning a new tree sequence while leaving the original uncha TableCollection.keep_intervals TableCollection.delete_sites TableCollection.trim + TableCollection.shift TableCollection.union TableCollection.delete_older ``` diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index efa72bc343..a869bcd09d 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -12,6 +12,11 @@ associated with each individual as a numpy array. (:user:`benjeffery`, :pr:`3153`) +- Add ``shift`` method to both ``TableCollection`` and ``TreeSequence`` classes + allowing the coordinate system to be shifted, and ``TreeSequence.concatenate`` + so a set of tree sequences can be added to the right of an existing one. 
+ (:user:`hyanwong`, :pr:`3165`, :issue:`3164`) + **Fixes** diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 47be14e4bd..4864e6bfa1 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -7074,6 +7074,182 @@ def test_failure_with_migrations(self): ts.trim() +class TestShift: + """ + Test the shift functionality + """ + + @pytest.mark.parametrize("shift", [-0.5, 0, 0.5]) + def test_shift(self, shift): + ts = tskit.Tree.generate_comb(2, span=2).tree_sequence + tables = ts.dump_tables() + tables.delete_intervals([[0, 1]], simplify=False) + tables.sites.add_row(1.5, "A") + ts = tables.tree_sequence() + ts = ts.shift(shift) + assert ts.sequence_length == 2 + shift + assert np.min(ts.tables.edges.left) == 1 + shift + assert np.max(ts.tables.edges.right) == 2 + shift + assert np.all(ts.tables.sites.position == 1.5 + shift) + assert len(list(ts.trees())) == ts.num_trees + + def test_sequence_length(self): + ts = tskit.Tree.generate_comb(2).tree_sequence + ts = ts.shift(1, sequence_length=3) + assert ts.sequence_length == 3 + ts = ts.shift(-1, sequence_length=1) + assert ts.sequence_length == 1 + + def test_empty(self): + empty_ts = tskit.TableCollection(1.0).tree_sequence() + empty_ts = empty_ts.shift(1) + assert empty_ts.sequence_length == 2 + empty_ts = empty_ts.shift(-1.5) + assert empty_ts.sequence_length == 0.5 + assert empty_ts.num_nodes == 0 + + def test_migrations(self): + tables = tskit.Tree.generate_comb(2, span=2).tree_sequence.dump_tables() + tables.populations.add_row() + tables.migrations.add_row(0, 1, 0, 0, 0, 0) + ts = tables.tree_sequence().shift(10) + assert np.all(ts.tables.migrations.left == 10) + assert np.all(ts.tables.migrations.right == 11) + + def test_provenance(self): + ts = tskit.Tree.generate_comb(2).tree_sequence + ts = ts.shift(1, record_provenance=False) + params = json.loads(ts.provenance(-1).record)["parameters"] + assert params["command"] != "shift" + ts = ts.shift(1, 
sequence_length=9) + params = json.loads(ts.provenance(-1).record)["parameters"] + assert params["command"] == "shift" + assert params["value"] == 1 + assert params["sequence_length"] == 9 + + def test_too_negative(self): + ts = tskit.Tree.generate_comb(2).tree_sequence + with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_SEQUENCE_LENGTH"): + ts.shift(-1) + + def test_bad_seq_len(self): + ts = tskit.Tree.generate_comb(2).tree_sequence + with pytest.raises( + tskit.LibraryError, match="TSK_ERR_RIGHT_GREATER_SEQ_LENGTH" + ): + ts.shift(1, sequence_length=1) + + +class TestConcatenate: + def test_simple(self): + ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence + ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence + assert ts1.num_samples == ts2.num_samples + assert ts1.num_nodes != ts2.num_nodes + joint_ts = ts1.concatenate(ts2) + assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 5 + assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length + assert joint_ts.num_samples == ts1.num_samples + ts3 = joint_ts.delete_intervals([[2, 5]]).rtrim() + # Have to simplify here, to remove the redundant nodes + assert ts3.equals(ts1.simplify(), ignore_provenance=True) + ts4 = joint_ts.delete_intervals([[0, 2]]).ltrim() + assert ts4.equals(ts2.simplify(), ignore_provenance=True) + + def test_multiple(self): + np.random.seed(42) + ts3 = [ + tskit.Tree.generate_comb(5, span=2).tree_sequence, + tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence, + tskit.Tree.generate_star(5, span=5).tree_sequence, + ] + for i in range(1, len(ts3)): + # shuffle the sample nodes so they don't have the same IDs + ts3[i] = ts3[i].subset(np.random.permutation(ts3[i].num_nodes)) + assert not np.all(ts3[0].samples() == ts3[1].samples()) + assert not np.all(ts3[0].samples() == ts3[2].samples()) + assert not np.all(ts3[1].samples() == ts3[2].samples()) + ts = ts3[0].concatenate(*ts3[1:]) + assert ts.sequence_length == 
sum([t.sequence_length for t in ts3]) + assert ts.num_nodes - ts.num_samples == sum( + [t.num_nodes - t.num_samples for t in ts3] + ) + assert np.all(ts.samples() == ts3[0].samples()) + + def test_empty(self): + empty_ts = tskit.TableCollection(10).tree_sequence() + ts = empty_ts.concatenate(empty_ts, empty_ts, empty_ts) + assert ts.num_nodes == 0 + assert ts.sequence_length == 40 + + def test_samples_at_end(self): + ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence + ts2 = tskit.Tree.generate_balanced(5, arity=3, span=3).tree_sequence + # reverse the node order + ts1 = ts1.subset(np.arange(ts1.num_nodes)[::-1]) + assert ts1.num_samples == ts2.num_samples + assert np.all(ts1.samples() != ts2.samples()) + joint_ts = ts1.concatenate(ts2) + assert joint_ts.num_samples == ts1.num_samples + assert np.all(joint_ts.samples() == ts1.samples()) + + def test_internal_samples(self): + tables = tskit.Tree.generate_comb(4, span=2).tree_sequence.dump_tables() + nodes_flags = tables.nodes.flags + nodes_flags[:] = tskit.NODE_IS_SAMPLE + nodes_flags[-1] = 0 # Only root is not a sample + tables.nodes.flags = nodes_flags + ts = tables.tree_sequence() + joint_ts = ts.concatenate(ts) + assert joint_ts.num_samples == ts.num_samples + assert joint_ts.num_nodes == ts.num_nodes + 1 + assert joint_ts.sequence_length == ts.sequence_length * 2 + + def test_some_shared_samples(self): + ts1 = tskit.Tree.generate_comb(4, span=2).tree_sequence + ts2 = tskit.Tree.generate_balanced(8, arity=3, span=3).tree_sequence + shared = np.full(ts2.num_nodes, tskit.NULL) + shared[0] = 1 + shared[1] = 0 + joint_ts = ts1.concatenate(ts2, node_mappings=[shared]) + assert joint_ts.sequence_length == ts1.sequence_length + ts2.sequence_length + assert joint_ts.num_samples == ts1.num_samples + ts2.num_samples - 2 + assert joint_ts.num_nodes == ts1.num_nodes + ts2.num_nodes - 2 + + def test_provenance(self): + ts = tskit.Tree.generate_comb(2).tree_sequence + ts = ts.concatenate(ts, record_provenance=False) + 
params = json.loads(ts.provenance(-1).record)["parameters"] + assert params["command"] != "concatenate" + + ts = ts.concatenate(ts) + params = json.loads(ts.provenance(-1).record)["parameters"] + assert params["command"] == "concatenate" + + def test_unequal_samples(self): + ts1 = tskit.Tree.generate_comb(5, span=2).tree_sequence + ts2 = tskit.Tree.generate_balanced(4, arity=3, span=3).tree_sequence + with pytest.raises(ValueError, match="must have the same number of samples"): + ts1.concatenate(ts2) + + @pytest.mark.skip( + reason="union bug: https://github.com/tskit-dev/tskit/issues/3168" + ) + def test_duplicate_ts(self): + ts1 = tskit.Tree.generate_comb(3, span=4).tree_sequence + ts = ts1.keep_intervals([[0, 1]]).trim() # a quarter of the original + nm = np.arange(ts.num_nodes) # all nodes identical + ts2 = ts.concatenate(ts, ts, ts, node_mappings=[nm] * 3, add_populations=False) + ts2 = ts2.simplify() # squash the edges + assert ts1.equals(ts2, ignore_provenance=True) + + def test_node_mappings_bad_len(self): + ts = tskit.Tree.generate_comb(3, span=2).tree_sequence + nm = np.arange(ts.num_nodes) + with pytest.raises(ValueError, match="same number of node_mappings"): + ts.concatenate(ts, ts, ts, node_mappings=[nm, nm]) + + class TestMissingData: """ Test various aspects of missing data functionality diff --git a/python/tskit/tables.py b/python/tskit/tables.py index c36cd70822..bc078164c0 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -3991,6 +3991,40 @@ def trim(self, record_provenance=True): record=json.dumps(provenance.get_provenance_dict(parameters)) ) + def shift(self, value, *, sequence_length=None, record_provenance=True): + """ + Shift the coordinate system (used by edges, sites, and migrations) of this + TableCollection by a given value. This is identical to :meth:`TreeSequence.shift` + but acts *in place* to alter the data in this :class:`TableCollection`. + + .. 
note:: + No attempt is made to check that the new coordinate system or sequence length + is valid: if you wish to do this, use :meth:`TreeSequence.shift` instead. + + :param value: The amount by which to shift the coordinate system. + :param sequence_length: The new sequence length of the tree sequence. If + ``None`` (default) add ``value`` to the sequence length. + """ + self.drop_index() + self.edges.left += value + self.edges.right += value + self.migrations.left += value + self.migrations.right += value + self.sites.position += value + if sequence_length is None: + self.sequence_length += value + else: + self.sequence_length = sequence_length + if record_provenance: + parameters = { + "command": "shift", + "value": value, + "sequence_length": sequence_length, + } + self.provenances.add_row( + record=json.dumps(provenance.get_provenance_dict(parameters)) + ) + def delete_older(self, time): """ Deletes edge, mutation and migration information at least as old as diff --git a/python/tskit/trees.py b/python/tskit/trees.py index e0100b7e7a..2c8b2eb262 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -32,6 +32,7 @@ import functools import io import itertools +import json import math import numbers import warnings @@ -46,6 +47,7 @@ import tskit.combinatorics as combinatorics import tskit.drawing as drawing import tskit.metadata as metadata_module +import tskit.provenance as provenance import tskit.tables as tables import tskit.text_formats as text_formats import tskit.util as util @@ -7060,6 +7062,104 @@ def trim(self, record_provenance=True): tables.trim(record_provenance) return tables.tree_sequence() + def shift(self, value, sequence_length=None, record_provenance=True): + """ + Shift the coordinate system (used by edges and sites) of this TreeSequence by + a given value. Positive values shift the coordinate system to the right, negative + values to the left. 
The sequence length of the tree sequence will be changed by + ``value``, unless ``sequence_length`` is given, in which case this will be used + for the new sequence length. + + .. note:: + By setting ``value=0``, this method will simply return a tree sequence + with a new sequence length. + + :param value: The amount by which to shift the coordinate system. + :param sequence_length: The new sequence length of the tree sequence. If + ``None`` (default) add ``value`` to the sequence length. + :raises tskit.LibraryError: If the new coordinate system is invalid (e.g., if + shifting the coordinate system results in negative coordinates). + """ + tables = self.dump_tables() + tables.shift( + value=value, + sequence_length=sequence_length, + record_provenance=record_provenance, + ) + return tables.tree_sequence() + + def concatenate( + self, *args, node_mappings=None, record_provenance=True, add_populations=None + ): + r""" + Concatenate a set of tree sequences to the right of this one, by repeatedly + calling :meth:`union` with an (optional) + node mapping for each of the ``args``. If any node mapping is ``None`` + only map the sample nodes between the input tree sequence and this one, + based on the numerical order of sample node IDs. + + .. note:: + To add gaps between the concatenated tables, use :meth:`shift` or + to remove gaps, use :meth:`trim` before concatenating. + + :param TreeSequence \*args: A list of other tree sequences to append to + the right of this one. + :param Union[list, None] node_mappings: A list of node mappings for each + input tree sequence in ``args``. Each should either be an array of + integers of the same length as the number of nodes in the equivalent + input tree sequence (see :meth:`union` for details), or ``None``. + If ``None``, only sample nodes are mapped to each other. + Default: ``None``, treated as ``[None] * len(args)``. 
+ :param bool record_provenance: If True (default), record details of this + call to ``concatenate`` in the returned tree sequence's provenance + information (Default: True). + :param bool add_populations: If True (default), nodes new to ``self`` will + be assigned new population IDs (see :meth:`union`) + """ + if node_mappings is None: + node_mappings = [None] * len(args) + if add_populations is None: + add_populations = True + if len(node_mappings) != len(args): + raise ValueError( + "You must provide the same number of node_mappings as args" + ) + + samples = self.samples() + tables = self.dump_tables() + tables.drop_index() + + for node_mapping, other in zip(node_mappings, args): + if node_mapping is None: + other_samples = other.samples() + if len(other_samples) != len(samples): + raise ValueError( + "each `other` must have the same number of samples as `self`" + ) + node_mapping = np.full(other.num_nodes, tskit.NULL, dtype=np.int32) + node_mapping[other_samples] = samples + other_tables = other.dump_tables() + other_tables.shift(tables.sequence_length, record_provenance=False) + tables.sequence_length = other_tables.sequence_length + # NB: should we use a different default for add_populations? + tables.union( + other_tables, + node_mapping=node_mapping, + check_shared_equality=False, # Else checks fail with internal samples + record_provenance=False, + add_populations=add_populations, + ) + if record_provenance: + parameters = { + "command": "concatenate", + "TODO": "add concatenate parameters", # tricky as both have provenances + } + tables.provenances.add_row( + record=json.dumps(provenance.get_provenance_dict(parameters)) + ) + + return tables.tree_sequence() + def split_edges(self, time, *, flags=None, population=None, metadata=None): """ Returns a copy of this tree sequence in which we replace any