Skip to content

Commit 2099a25

Browse files
committed
Add mutations derived state
1 parent 009f521 commit 2099a25

2 files changed

Lines changed: 75 additions & 30 deletions

File tree

python/_tskitmodule.c

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11195,6 +11195,24 @@ TreeSequence_get_mutations_time(TreeSequence *self, void *closure)
1119511195
return ret;
1119611196
}
1119711197

11198+
#if HAVE_NUMPY_2
11199+
static PyObject *
11200+
TreeSequence_get_mutations_derived_state(TreeSequence *self, void *closure)
11201+
{
11202+
PyObject *ret = NULL;
11203+
tsk_mutation_table_t mutations;
11204+
11205+
if (TreeSequence_check_state(self) != 0) {
11206+
goto out;
11207+
}
11208+
mutations = self->tree_sequence->tables->mutations;
11209+
ret = TreeSequence_decode_ragged_string_column(
11210+
self, mutations.num_rows, mutations.derived_state, mutations.derived_state_offset);
11211+
out:
11212+
return ret;
11213+
}
11214+
#endif
11215+
1119811216
static PyObject *
1119911217
TreeSequence_get_mutations_metadata(TreeSequence *self, void *closure)
1120011218
{
@@ -11792,6 +11810,11 @@ static PyGetSetDef TreeSequence_getsetters[] = {
1179211810
{ .name = "mutations_time",
1179311811
.get = (getter) TreeSequence_get_mutations_time,
1179411812
.doc = "The mutation time array" },
11813+
#if HAVE_NUMPY_2
11814+
{ .name = "mutations_derived_state",
11815+
.get = (getter) TreeSequence_get_mutations_derived_state,
11816+
.doc = "The mutation derived state array" },
11817+
#endif
1179511818
{ .name = "mutations_metadata",
1179611819
.get = (getter) TreeSequence_get_mutations_metadata,
1179711820
.doc = "The mutation metadata array" },

python/tests/test_lowlevel.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,69 +1976,91 @@ def test_individuals_nodes(self, ts_fixture):
19761976

19771977
@pytest.mark.skipif(not _tskit.HAS_NUMPY_2, reason="Requires NumPy 2.0+")
19781978
@pytest.mark.parametrize(
1979-
"site_lengths",
1979+
"string_array", ["sites_ancestral_state", "mutations_derived_state"]
1980+
)
1981+
@pytest.mark.parametrize(
1982+
"str_lengths",
19801983
["none", "all-0", "all-1", "all-2", "mixed", "very_long", "unicode"],
19811984
)
1982-
def test_sites_ancestral_state(self, ts_fixture, site_lengths):
1983-
if site_lengths == "none":
1985+
def test_string_arrays(self, ts_fixture, str_lengths, string_array):
1986+
if str_lengths == "none":
19841987
ts = tskit.TableCollection(1.0).tree_sequence()
19851988
else:
1986-
if site_lengths == "all-1":
1989+
if str_lengths == "all-1":
19871990
ts = ts_fixture
1988-
assert {len(site.ancestral_state) for site in ts.sites()} == {1}
1991+
if string_array == "sites_ancestral_state":
1992+
assert ts.num_sites > 0
1993+
assert {len(site.ancestral_state) for site in ts.sites()} == {1}
1994+
elif string_array == "mutations_derived_state":
1995+
assert ts.num_mutations > 0
1996+
assert {len(mut.derived_state) for mut in ts.mutations()} == {1}
19891997
else:
19901998
tables = ts_fixture.dump_tables()
1991-
sites = tables.sites.copy()
1992-
tables.sites.clear()
1993-
1994-
ancestral_state_map = {
1995-
"all-0": lambda i, site: "",
1996-
"all-2": lambda i, site: chr(ord("A") + (i % 26)) * 2,
1997-
"mixed": lambda i, site: chr(ord("A") + (i % 26)) * (i % 20),
1998-
"very_long": lambda i, site: "A" * 100_000_000 if i == 1 else "A",
1999-
"unicode": lambda i, site: "💩" * (i + 1),
1999+
2000+
str_map = {
2001+
"all-0": lambda i, item: "",
2002+
"all-2": lambda i, item: chr(ord("A") + (i % 26)) * 2,
2003+
"mixed": lambda i, item: chr(ord("A") + (i % 26)) * (i % 20),
2004+
"very_long": lambda i, item: "A" * 100_000_000 if i == 1 else "T",
2005+
"unicode": lambda i, item: "🧬" * (i + 1),
20002006
}
20012007

2002-
get_ancestral_state = ancestral_state_map[site_lengths]
2003-
for i, site in enumerate(sites):
2004-
tables.sites.append(
2005-
site.replace(ancestral_state=get_ancestral_state(i, site))
2006-
)
2008+
if string_array == "sites_ancestral_state":
2009+
sites = tables.sites.copy()
2010+
tables.sites.clear()
2011+
get_ancestral_state = str_map[str_lengths]
2012+
for i, site in enumerate(sites):
2013+
tables.sites.append(
2014+
site.replace(ancestral_state=get_ancestral_state(i, site))
2015+
)
2016+
elif string_array == "mutations_derived_state":
2017+
mutations = tables.mutations.copy()
2018+
tables.mutations.clear()
2019+
get_derived_state = str_map[str_lengths]
2020+
for i, mutation in enumerate(mutations):
2021+
tables.mutations.append(
2022+
mutation.replace(
2023+
derived_state=get_derived_state(i, mutation)
2024+
)
2025+
)
2026+
20072027
ts = tables.tree_sequence()
20082028
ll_ts = ts.ll_tree_sequence
20092029

2010-
a = ll_ts.sites_ancestral_state
2030+
a = getattr(ll_ts, string_array)
2031+
20112032
# Contents
2012-
if site_lengths == "none":
2033+
if str_lengths == "none":
20132034
assert a.size == 0
20142035
else:
2015-
for site in ts.sites():
2016-
assert a[site.id] == site.ancestral_state
2036+
if string_array == "sites_ancestral_state":
2037+
for site in ts.sites():
2038+
assert a[site.id] == site.ancestral_state
2039+
elif string_array == "mutations_derived_state":
2040+
for mutation in ts.mutations():
2041+
assert a[mutation.id] == mutation.derived_state
20172042

20182043
# Read only
20192044
with pytest.raises(AttributeError, match="not writable"):
2020-
ll_ts.sites_ancestral_state = None
2045+
setattr(ll_ts, string_array, None)
20212046
with pytest.raises(AttributeError, match="not writable"):
2022-
del ll_ts.sites_ancestral_state
2047+
delattr(ll_ts, string_array)
20232048

20242049
with pytest.raises(ValueError, match="assignment destination"):
20252050
a[:] = 0
20262051
with pytest.raises(ValueError, match="assignment destination"):
20272052
a[0] = 0
2028-
if site_lengths in [("all-1",)]:
2029-
with pytest.raises(ValueError, match="cannot set WRITEABLE"):
2030-
a.setflags(write=True)
20312053

20322054
# Properties
20332055
assert a.dtype == np.dtypes.StringDType()
20342056
assert a.flags.aligned
20352057
assert a.flags.c_contiguous
2036-
b = ll_ts.sites_ancestral_state
2058+
b = getattr(ll_ts, string_array)
20372059
assert a is not b
20382060
assert np.all(a == b)
20392061

20402062
# Lifetime
2041-
a1 = ll_ts.sites_ancestral_state
2063+
a1 = getattr(ll_ts, string_array)
20422064
a2 = a1.copy()
20432065
assert a1 is not a2
20442066
del ll_ts

0 commit comments

Comments
 (0)