@@ -2511,34 +2511,30 @@ def r2_ij(X, n):
25112511 pAB , pAb , paB = X / n
25122512 pA = pAb + pAB
25132513 pB = paB + pAB
2514- D2_ij = np .prod (pAB - (pA * pB ))
2515- denom = np .prod (np .sqrt (pA * pB * (1 - pA ) * (1 - pB )))
2514+ D2_ij = np .prod (pAB - (pA * pB ), keepdims = True )
2515+ denom = np .prod (np .sqrt (pA * pB * (1 - pA ) * (1 - pB )), keepdims = True )
25162516 with suppress_overflow_div0_warning ():
2517- return np . expand_dims ( D2_ij / denom , axis = 0 )
2517+ return D2_ij / denom
25182518
25192519 @staticmethod
25202520 def D2_ij (X , n ):
25212521 pAB , pAb , paB = X / n
25222522 pA = pAb + pAB
25232523 pB = paB + pAB
2524- return np .expand_dims ( np . prod (pAB - (pA * pB )), axis = 0 )
2524+ return np .prod (pAB - (pA * pB ), keepdims = True )
25252525
25262526 @staticmethod
25272527 def D2_ij_unbiased (X , n ):
2528- """
2529- NB: the two sample sets must be disjoint
2530- we have no way for testing equality
2531- """
2528+ """NB: We use double brackets here to preserve the output shape of (1,)"""
25322529 AB , Ab , aB = X
25332530 ab = n - X .sum (0 )
2534- return np .expand_dims (
2535- (Ab [0 ] * aB [0 ] - AB [0 ] * ab [0 ])
2536- * (Ab [1 ] * aB [1 ] - AB [1 ] * ab [1 ])
2537- / n [0 ]
2538- / (n [0 ] - 1 )
2539- / n [1 ]
2540- / (n [1 ] - 1 ),
2541- axis = 0 ,
2531+ return (
2532+ (Ab [[0 ]] * aB [[0 ]] - AB [[0 ]] * ab [[0 ]])
2533+ * (Ab [[1 ]] * aB [[1 ]] - AB [[1 ]] * ab [[1 ]])
2534+ / n [[0 ]]
2535+ / (n [[0 ]] - 1 )
2536+ / n [[1 ]]
2537+ / (n [[1 ]] - 1 )
25422538 )
25432539
25442540
@@ -2624,7 +2620,7 @@ def test_general_one_way_two_locus_stat_multiallelic(stat):
26242620 elif stat in {"D" , "r" , "D_prime" }:
26252621 result = ts .two_locus_count_stat ([ts .samples ()], func , 1 , polarised = True )
26262622 else :
2627- # default norm func is lambda X, n, nA, nB: np.expand_dims( 1 / (nA * nB), axis=0)
2623+ # default norm func is ` lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]`
26282624 result = ts .two_locus_count_stat ([ts .samples ()], func , 1 )
26292625 np .testing .assert_array_almost_equal (ts .ld_matrix (stat = stat ), result )
26302626
@@ -2641,13 +2637,14 @@ def test_general_two_way_two_locus_stat_multiallelic(stat):
26412637 ts = tsutil .all_fields_ts ()
26422638 func = getattr (GeneralStatFuncs , stat )
26432639 if stat == "r2_ij" :
2644-
2645- def norm_f (X , n , nA , nB ):
2646- return np .expand_dims (X [0 ].sum () / n .sum (), axis = 0 )
2647-
2648- result = ts .two_locus_count_stat ([ts .samples (), ts .samples ()], func , 1 , norm_f )
2640+ result = ts .two_locus_count_stat (
2641+ [ts .samples (), ts .samples ()],
2642+ func ,
2643+ 1 ,
2644+ lambda X , n , nA , nB : X [0 ].sum (keepdims = True ) / n .sum (),
2645+ )
26492646 else :
2650- # default norm func is lambda X, n, nA, nB: np.expand_dims( 1 / (nA * nB), axis=0)
2647+ # default norm func is ` lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]`
26512648 result = ts .two_locus_count_stat ([ts .samples (), ts .samples ()], func , 1 )
26522649 np .testing .assert_array_almost_equal (
26532650 ts .ld_matrix (
0 commit comments