Commit 3f4c0ed

Add/refine tests, draft docstring
Clean up dimension handling around summary functions and normalisation. There is a slight speed advantage (according to a microbenchmark) and a huge readability advantage to simply returning [value]. I keep all computations specifying `keepdims`, but remove list indexing (i.e. `AB[[0]]`) in favor of returning a list with a single scalar. It turns out that vectorised numpy functions are actually slower in some cases because the data we're operating on is so small.

Also, fix the default normalisation function so that it works on both one-way and two-way statistics. Users will still need to specify `hap_norm` when appropriate (and a special case of `hap_norm` for two-way stats).

Per Peter's comment, I investigated dimension dropping and indeed, general stats don't drop dimensions, so I removed the dimension dropping code. However, we return a matrix of shape `(m, m, k)` and we want `(k, m, m)`, so `np.moveaxis` is still needed.

Added tests:

* Multiallelic multi sample-set. This tests operations on two sample sets for multiallelic data (which exercises the norm function with multiple sample sets). This test highlighted the slight changes needed to the default normalisation function.
* Multi outputs. This test mimics a two-way stat called on multiple indexes. It shows and tests the ability to compute multiple statistics from the same haplotype counts matrix (which is especially useful with the explosion of possible summary functions in three-way and four-way stats).

In our biallelic test case, I also assert that the normalisation function is never called and add a note about polarisation.

Finally, I add a draft docstring, but to complete this I think that the two-locus docs are required. Also, I'd like to add some general documentation.
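The axis reordering described in this message can be checked in isolation with plain NumPy. The sketch below is only an illustration: the array and its sizes `m` and `k` are made up rather than taken from tskit. It suggests that a single `np.moveaxis` call produces the same layout as the chained `swapaxes` calls it replaces.

```python
import numpy as np

# Hypothetical sizes: m sites (or trees) and k outputs per pair.
m, k = 4, 3
result = np.arange(m * m * k).reshape(m, m, k)

# Previous approach in the diff: two chained axis swaps.
swapped = result.swapaxes(0, 2).swapaxes(1, 2)

# Approach adopted in the diff: move the last axis to the front.
moved = np.moveaxis(result, -1, 0)

print(moved.shape)
print(np.array_equal(swapped, moved))
```

Both forms leave the row and column order of each per-output matrix unchanged, so `moved[i]` is the same as `result[:, :, i]`.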
1 parent 4c04ff3 commit 3f4c0ed

2 files changed

Lines changed: 170 additions & 29 deletions

python/tests/test_ld_matrix.py

Lines changed: 98 additions & 25 deletions
@@ -2400,16 +2400,19 @@ def test_multipopulation_r2_varying_unequal_set_sizes(genotypes, sample_sets, ex

 class GeneralStatFuncs:
     """
-    functions take X, n as parameters where
+    Summary functions take X, n as parameters where X is a matrix of haplotype
+    counts per sample set and n is a vector of sample set sizes. X has shape (3, k)
+    and n has shape (k, ), where k is the number of sample sets. The rows of X
+    contain haplotype counts for AB, Ab, aB (capitalized == derived).

-    X: shape=(3, #ss)
+    X: shape=(3, k)
              sample sets
-    count AB [[ ]
-    count Ab [ ]
-    count aB [ ]]
+    count AB [[ #ss1, #ss2, ... ]
+    count Ab [ #ss1, #ss2, ... ]
+    count aB [ #ss1, #ss2, ... ]]

-    n: shape=(#ss, )
-    [ ]
+    n: shape=(k, )
+    [ #ss1, #ss2, ... ]
     """

     @staticmethod
@@ -2480,37 +2483,39 @@ def pi2(X, n):
     def D2_unbiased(X, n):
         AB, Ab, aB = X
         ab = n - X.sum(0)
-        return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * (
+        return (
             ((aB**2) * (Ab - 1) * Ab)
             + ((ab - 1) * ab * (AB - 1) * AB)
             - (aB * Ab * (Ab + (2 * ab * AB) - 1))
-        )
+        ) / (n * (n - 1) * (n - 2) * (n - 3))

     @staticmethod
     def Dz_unbiased(X, n):
         AB, Ab, aB = X
         ab = n - X.sum(0)
-        return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * (
+        return (
             (((AB * ab) - (Ab * aB)) * (aB + ab - AB - Ab) * (Ab + ab - AB - aB))
             - ((AB * ab) * (AB + ab - Ab - aB - 2))
             - ((Ab * aB) * (Ab + aB - AB - ab - 2))
-        )
+        ) / (n * (n - 1) * (n - 2) * (n - 3))

     @staticmethod
     def pi2_unbiased(X, n):
         AB, Ab, aB = X
         ab = n - X.sum(0)
-        return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * (
+        return (
             ((AB + Ab) * (aB + ab) * (AB + aB) * (Ab + ab))
             - ((AB * ab) * (AB + ab + (3 * Ab) + (3 * aB) - 1))
             - ((Ab * aB) * (Ab + aB + (3 * AB) + (3 * ab) - 1))
-        )
+        ) / (n * (n - 1) * (n - 2) * (n - 3))

+    # Two-way statistics have the _ij suffix.
     @staticmethod
     def r2_ij(X, n):
         pAB, pAb, paB = X / n
         pA = pAb + pAB
         pB = paB + pAB
+        # keepdims preserves the output shape of (1, )
         D2_ij = np.prod(pAB - (pA * pB), keepdims=True)
         denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB)), keepdims=True)
         with suppress_overflow_div0_warning():
@@ -2525,17 +2530,37 @@ def D2_ij(X, n):

     @staticmethod
     def D2_ij_unbiased(X, n):
-        """NB: We use double brackets here to preserve the output shape of (1,)"""
+        """The identity of the sample sets is up to the user."""
         AB, Ab, aB = X
         ab = n - X.sum(0)
-        return (
-            (Ab[[0]] * aB[[0]] - AB[[0]] * ab[[0]])
-            * (Ab[[1]] * aB[[1]] - AB[[1]] * ab[[1]])
-            / n[[0]]
-            / (n[[0]] - 1)
-            / n[[1]]
-            / (n[[1]] - 1)
+        return [
+            (Ab[0] * aB[0] - AB[0] * ab[0])
+            * (Ab[1] * aB[1] - AB[1] * ab[1])
+            / (n[0] * (n[0] - 1) * n[1] * (n[1] - 1))
+        ]
+
+    @staticmethod
+    def D2_ii_ij_jj_unbiased(X, n):
+        """
+        Multiple stats can be computed from the same data. The identity of the
+        sample sets is up to the user. This function assumes two sample sets.
+        """
+        AB, Ab, aB = X
+        ab = n - X.sum(0)
+
+        # unbiased estimator for equal sample sets
+        ii, jj = (
+            AB * (AB - 1) * ab * (ab - 1)
+            + Ab * (Ab - 1) * aB * (aB - 1)
+            - 2 * AB * Ab * aB * ab
+        ) / (n * (n - 1) * (n - 2) * (n - 3))
+        # unbiased estimator for disjoint sample sets
+        ij = (
+            (Ab[0] * aB[0] - AB[0] * ab[0])
+            * (Ab[1] * aB[1] - AB[1] * ab[1])
+            / (n[0] * (n[0] - 1) * n[1] * (n[1] - 1))
         )
+        return [ii, ij, jj]


 @pytest.fixture(scope="module")
@@ -2573,7 +2598,17 @@ def ts_multiallelic_fixture():
 def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture):
     ts = ts_100_samp_with_sites_fixture
     sample_sets = [ts.samples()[0:50], ts.samples()[50:100]]
-    ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2)
+
+    # In addition to not needing a normalisation function, normalisation is also
+    # not required because these sites are biallelic.
+    def assert_no_norm_func(*_):
+        raise Exception(
+            "Normalisation function should not be called for biallelic sites"
+        )
+
+    ldg = ts.two_locus_count_stat(
+        sample_sets, getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func
+    )
     ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat)
     np.testing.assert_array_almost_equal(ldg, ld)

@@ -2584,7 +2619,7 @@ def test_general_two_locus_two_way_site_stat(stat, ts_100_samp_with_sites_fixtur
     sample_sets = [ts.samples()[0:50], ts.samples()[50:100]]
     ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1)
     ld = ts.ld_matrix(
-        sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=(0, 1)
+        sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=[(0, 1)]
     )
     np.testing.assert_array_almost_equal(ldg, ld)

@@ -2599,7 +2634,24 @@ def test_general_one_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixtu
         [ts.samples()], general_func, 1, norm_f=norm_func, polarised=polarised
     )
     ld = ts.ld_matrix(stat=stat)
-    np.testing.assert_array_almost_equal(ld, ldg)
+    # ld_matrix drops dims, expand for comparison
+    np.testing.assert_array_almost_equal(ldg, np.expand_dims(ld, 0))
+
+
+@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys())
+def test_general_one_way_two_locus_stat_multiallelic_multi_sample_set(
+    stat, ts_multiallelic_fixture
+):
+    ts = ts_multiallelic_fixture
+    general_func = getattr(GeneralStatFuncs, stat)
+    norm_func = (lambda X, n, nA, nB: X[0] / n) if stat == "r2" else None
+    polarised = POLARIZATION[SUMMARY_FUNCS[stat]]
+    sample_sets = [ts.samples(), ts.samples()]
+    ldg = ts.two_locus_count_stat(
+        sample_sets, general_func, 2, norm_f=norm_func, polarised=polarised
+    )
+    ld = ts.ld_matrix(stat=stat, sample_sets=sample_sets)
+    np.testing.assert_array_almost_equal(ldg, ld)


 @pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"])
@@ -2616,4 +2668,25 @@ def test_general_two_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixtu
     ld = ts.ld_matrix(
         stat=stat.replace("_ij", ""), indexes=(0, 1), sample_sets=sample_sets
     )
-    np.testing.assert_array_almost_equal(ld, ldg)
+    # ld_matrix drops dims, expand for comparison
+    np.testing.assert_array_almost_equal(ldg, np.expand_dims(ld, 0))
+
+
+def test_general_two_locus_multi_outputs():
+    ts = msprime.sim_mutations(
+        msprime.sim_ancestry(
+            4, recombination_rate=0.1, sequence_length=100, random_seed=123
+        ),
+        rate=0.1,
+        random_seed=123,
+    )
+    assert ts.num_samples == 8, "8 samples are required"
+    assert max({len(s.mutations) for s in ts.sites()}) > 2, (
+        "At least one multiallelic site required"
+    )
+    A = ts.samples()[0:4]
+    B = ts.samples()[4:]
+
+    ldg = ts.two_locus_count_stat([A, B], GeneralStatFuncs.D2_ii_ij_jj_unbiased, 3)
+    ld = ts.ld_matrix([A, B], stat="D2_unbiased", indexes=[(0, 0), (0, 1), (1, 1)])
+    np.testing.assert_array_almost_equal(ldg, ld)
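The switch from double-bracket indexing to returning a one-element list can be tried without a tree sequence. The sketch below copies both forms of `D2_ij_unbiased` from the diff above and calls them on made-up haplotype counts. The counts, the sample set sizes and the `_old` suffix are hypothetical and chosen only for illustration.

```python
import numpy as np


def D2_ij_unbiased_old(X, n):
    # Earlier form: double brackets keep every intermediate at shape (1,).
    AB, Ab, aB = X
    ab = n - X.sum(0)
    return (
        (Ab[[0]] * aB[[0]] - AB[[0]] * ab[[0]])
        * (Ab[[1]] * aB[[1]] - AB[[1]] * ab[[1]])
        / n[[0]]
        / (n[[0]] - 1)
        / n[[1]]
        / (n[[1]] - 1)
    )


def D2_ij_unbiased(X, n):
    # Form adopted in the diff: scalar arithmetic, then wrap in a list.
    AB, Ab, aB = X
    ab = n - X.sum(0)
    return [
        (Ab[0] * aB[0] - AB[0] * ab[0])
        * (Ab[1] * aB[1] - AB[1] * ab[1])
        / (n[0] * (n[0] - 1) * n[1] * (n[1] - 1))
    ]


# Hypothetical counts: rows are AB, Ab, aB; columns are two sample sets.
X = np.array([[2, 1], [1, 1], [1, 2]])
n = np.array([4, 4])

old = D2_ij_unbiased_old(X, n)
new = D2_ij_unbiased(X, n)
print(old.shape, len(new))
print(np.allclose(old, new))
```

Both versions yield a single value per call, so either should satisfy a `result_dim` of 1. The list form avoids creating several tiny arrays along the way.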

python/tskit/trees.py

Lines changed: 72 additions & 4 deletions
@@ -10942,14 +10942,84 @@ def two_locus_count_stat(
         positions=None,
         mode="site",
     ):
+        """
+        Compute two-locus statistics with a user-defined Python function that
+        operates on haplotype counts. TODO: reference modes in two-locus docs.
+        On each pair of sites or trees, the summary function is provided with
+        ``X``, a matrix with shape (3, k) and ``n``, a vector with shape (k,),
+        where k is the number of sample sets provided. ``X`` is a read-only
+        matrix whose rows contain haplotype counts per sample set (counts of AB,
+        Ab, aB) and ``n`` is a vector of sample set sizes.
+
+        .. note::
+            Because we are operating on very small matrices/vectors, vectorised
+            operations are often slower than operations on scalars. Simply
+            returning ``[value]`` can be faster than returning
+            ``value[np.newaxis,]`` or ``np.expand_dims(value, 0)``.
+
+        What follows is an example of computing ``D`` from a tree sequence. Many
+        more examples can be found in the test suite
+        ``test_ld_matrix.py::GeneralStatFuncs``. Let's begin with our summary
+        function, ``D``. We convert counts to proportions, then compute ``D``,
+        returning a numpy array with length equal to the number of sample sets.
+
+        .. code-block:: python
+            def D(X, n):
+                pAB, pAb, paB = X / n
+                pA = pAb + pAB
+                pB = paB + pAB
+                return pAB - (pA * pB)
+
+        ``norm_f`` is a normalisation function used to combine all computed
+        statistics for multiallelic allele pairs (TODO: see two-locus
+        docs). Biallelic sites do not require any normalisation (in fact, the
+        normalisation function is never called for biallelic sites). If
+        either site A or site B is multiallelic, then the normalisation function
+        will be called. The default normalisation function is identical to
+        ``total_norm`` shown in the example below. ``hap_norm`` is required for
+        normalising :math:`r^2`. Both of these examples return a numpy array
+        with length equal to the number of sample sets (for one-way stats).
+
+        .. code-block:: python
+            def total_norm(X, n, nA, nB):
+                return [1 / (nA * nB)] * result_dim
+
+            def hap_norm(X, n, nA, nB):
+                return X[0] / n
+
+        A simple call (without specifying normalisation) would look like this
+
+        .. code-block:: python
+            ts.two_locus_count_stat([ts.samples()], D, 1, polarised=True)
+
+        :param list sample_sets: A list of lists of Node IDs, specifying the
+            groups of nodes to compute the statistic with.
+        :param f: A function that takes two arguments - a two-dimensional array
+            with shape (3, k) and a one-dimensional array with shape (k, ) where
+            k is the number of sample sets.
+        :param int result_dim: The length of ``f`` and ``norm_f``'s return value.
+        :param norm_f: A function that takes four arguments - the first two are
+            the same as ``f``, the last two are scalars representing the
+            number of A and B alleles, respectively.
+        :param bool polarised: Whether to leave the ancestral state out of
+            computations: see :ref:`sec_stats` for more details.
+        :param list sites: TODO: two-locus docs
+        :param list positions: TODO: two-locus docs
+        :param str mode: A string giving the "type" of the statistic to be
+            computed (defaults to "site").
+        :return: An ndarray with shape equal to (TODO: reference two-locus docs,
+            no dimension dropping shape=(k, m, m) where k=num_sample_sets,
+            m=num_sites or num_trees).
+        """
         row_sites, col_sites = self.parse_sites(sites)
         row_positions, col_positions = self.parse_positions(positions)
         _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets)
         result = self._ll_tree_sequence.two_locus_count_stat(
             sample_set_sizes,
             sample_sets,
             f,
-            norm_f or (lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]),
+            # produce the same number of dims as output dimensions
+            norm_f or (lambda X, n, nA, nB: [1 / (nA * nB)] * result_dim),
             result_dim,
             polarised,
             row_sites,
@@ -10958,11 +11028,9 @@ def two_locus_count_stat(
             col_positions,
             mode,
         )
-        if result_dim == 1:  # drop dimension
-            return result.reshape(result.shape[:2])
         # Orient the data so that the first dimension is the sample set so that
         # we get one LD matrix per sample set.
-        return np.moveaxis(result, -1, 0)
+        return np.moveaxis(result, -1, 0)

     def ld_matrix(
         self,
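The summary and normalisation callbacks described in the draft docstring can be exercised on their own, outside tskit. The sketch below is a rough illustration under stated assumptions: the counts in `X`, the sizes in `n` and the allele counts `nA` and `nB` are invented, and `make_total_norm` is a hypothetical helper that closes over `result_dim` in the same way the default lambda in the diff does.

```python
import numpy as np


def D(X, n):
    # Convert counts to proportions, then compute D per sample set.
    pAB, pAb, paB = X / n
    pA = pAb + pAB
    pB = paB + pAB
    return pAB - (pA * pB)


def make_total_norm(result_dim):
    # Hypothetical helper: mirrors the default norm_f lambda in the diff,
    # which repeats one weight for every output dimension.
    def total_norm(X, n, nA, nB):
        return [1 / (nA * nB)] * result_dim

    return total_norm


def hap_norm(X, n, nA, nB):
    # Weight by the frequency of the AB haplotype in each sample set.
    return X[0] / n


# Made-up inputs: rows of X are AB, Ab, aB; columns are k = 2 sample sets.
X = np.array([[2, 1], [1, 1], [1, 2]])
n = np.array([4, 4])
nA, nB = 2, 3  # hypothetical allele counts at the two sites

d = D(X, n)
w_total = make_total_norm(len(n))(X, n, nA, nB)
w_hap = hap_norm(X, n, nA, nB)
print(d, w_total, w_hap)
```

Each callback returns one value per sample set here, which is the length that `result_dim` would need to match for a one-way statistic.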
