Skip to content

Commit bdc8092

Browse files
committed
preserve native dimensions instead of expanding at the end
1 parent 0d8c5d3 commit bdc8092

2 files changed

Lines changed: 21 additions & 24 deletions

File tree

python/tests/test_ld_matrix.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2511,34 +2511,30 @@ def r2_ij(X, n):
25112511
pAB, pAb, paB = X / n
25122512
pA = pAb + pAB
25132513
pB = paB + pAB
2514-
D2_ij = np.prod(pAB - (pA * pB))
2515-
denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB)))
2514+
D2_ij = np.prod(pAB - (pA * pB), keepdims=True)
2515+
denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB)), keepdims=True)
25162516
with suppress_overflow_div0_warning():
2517-
return np.expand_dims(D2_ij / denom, axis=0)
2517+
return D2_ij / denom
25182518

25192519
@staticmethod
25202520
def D2_ij(X, n):
25212521
pAB, pAb, paB = X / n
25222522
pA = pAb + pAB
25232523
pB = paB + pAB
2524-
return np.expand_dims(np.prod(pAB - (pA * pB)), axis=0)
2524+
return np.prod(pAB - (pA * pB), keepdims=True)
25252525

25262526
@staticmethod
25272527
def D2_ij_unbiased(X, n):
2528-
"""
2529-
NB: the two sample sets must be disjoint
2530-
we have no way for testing equality
2531-
"""
2528+
"""NB: We use double brackets here to preserve the output shape of (1,)"""
25322529
AB, Ab, aB = X
25332530
ab = n - X.sum(0)
2534-
return np.expand_dims(
2535-
(Ab[0] * aB[0] - AB[0] * ab[0])
2536-
* (Ab[1] * aB[1] - AB[1] * ab[1])
2537-
/ n[0]
2538-
/ (n[0] - 1)
2539-
/ n[1]
2540-
/ (n[1] - 1),
2541-
axis=0,
2531+
return (
2532+
(Ab[[0]] * aB[[0]] - AB[[0]] * ab[[0]])
2533+
* (Ab[[1]] * aB[[1]] - AB[[1]] * ab[[1]])
2534+
/ n[[0]]
2535+
/ (n[[0]] - 1)
2536+
/ n[[1]]
2537+
/ (n[[1]] - 1)
25422538
)
25432539

25442540

@@ -2624,7 +2620,7 @@ def test_general_one_way_two_locus_stat_multiallelic(stat):
26242620
elif stat in {"D", "r", "D_prime"}:
26252621
result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True)
26262622
else:
2627-
# default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0)
2623+
# default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]`
26282624
result = ts.two_locus_count_stat([ts.samples()], func, 1)
26292625
np.testing.assert_array_almost_equal(ts.ld_matrix(stat=stat), result)
26302626

@@ -2641,13 +2637,14 @@ def test_general_two_way_two_locus_stat_multiallelic(stat):
26412637
ts = tsutil.all_fields_ts()
26422638
func = getattr(GeneralStatFuncs, stat)
26432639
if stat == "r2_ij":
2644-
2645-
def norm_f(X, n, nA, nB):
2646-
return np.expand_dims(X[0].sum() / n.sum(), axis=0)
2647-
2648-
result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f)
2640+
result = ts.two_locus_count_stat(
2641+
[ts.samples(), ts.samples()],
2642+
func,
2643+
1,
2644+
lambda X, n, nA, nB: X[0].sum(keepdims=True) / n.sum(),
2645+
)
26492646
else:
2650-
# default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0)
2647+
# default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]`
26512648
result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1)
26522649
np.testing.assert_array_almost_equal(
26532650
ts.ld_matrix(

python/tskit/trees.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10936,7 +10936,7 @@ def two_locus_count_stat(
1093610936
sample_sets,
1093710937
f,
1093810938
result_dim,
10939-
norm_f=lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0),
10939+
norm_f=lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,],
1094010940
polarised=False,
1094110941
sites=None,
1094210942
positions=None,

0 commit comments

Comments
 (0)