From 9f20af27df12dd2aebbf2f5e1831608392956ee5 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 2 Nov 2025 02:26:32 -0600 Subject: [PATCH 1/7] clarify semantics of searchsorted --- c/tskit/core.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c/tskit/core.c b/c/tskit/core.c index 5e5f828943..103bf041a8 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -853,7 +853,7 @@ tsk_blkalloc_free(tsk_blkalloc_t *self) } } -/* Mirrors the semantics of numpy's searchsorted function. Uses binary +/* Mirrors the semantics of numpy's searchsorted function (side='left'). Uses binary * search to find the index of the closest value in the array. */ tsk_size_t tsk_search_sorted(const double *restrict array, tsk_size_t size, double value) From 0afa4cec471e3d59044695cb4588ba076d0848cf Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 2 Nov 2025 02:33:06 -0600 Subject: [PATCH 2/7] first cut --- c/tskit/trees.c | 472 +++++++++++++++++++++++++++++++++++++++++- c/tskit/trees.h | 46 +++- python/_tskitmodule.c | 167 +++++++++++++++ python/tskit/trees.py | 87 +++++++- 4 files changed, 756 insertions(+), 16 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index f00bb83d28..675d6244b1 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3461,6 +3461,369 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl return ret; } +/* Find the distance span between two trees (ie the greatest and smallest distance that + * can be represented by the subtraction of two tree's intervals. + * NB: required condition: ivl_1 <= ivl_2 */ +static inline void +get_distance_bounds(const interval_t i1, const interval_t i2, interval_t *out) +{ + // Equal intervals will stretch into the negative. + // TODO: any others? If not, the following will suffice. 
+ // // If the intervals are equal + // if (i1.left == i2.left && i1.right == i2.right) { + // out->left = 0; + // out->right = i1.right; + // return; + // } + out->left = fmax(0, i2.left - i1.right); + out->right = i2.right - i1.left; +} + +// TODO: breaks at 0,1/0,1 -- span is into negative, no need to 1/2 stat. +static double +integrate_stat_over_window(const interval_t i1, const interval_t i2, + const interval_t bounds, double wl, double wr, double stat) +{ + double r2_len = fmin(i1.right - i1.left, i2.right - i2.left); + // Size of the center region is determined by the larger of the two + // intervals. It is zero if they are equal (triangle). + double r2_l_bound = fmin(i2.left - i1.left, i2.right - i1.right); + double r2_r_bound = bounds.right - r2_len; + // left and right values for each of the 3 regions to integrate over + // variable names are: r{region}_{left|right} + double r1_l = fmin(fmax(wl, bounds.left), r2_l_bound); + double r1_r = fmax(fmin(wr, r2_l_bound), bounds.left); + double r2_l = fmin(fmax(wl, r2_l_bound), r2_r_bound); + double r2_r = fmax(fmin(wr, r2_r_bound), r2_l_bound); + double r3_l = fmin(fmax(wl, r2_r_bound), bounds.right); + double r3_r = fmax(fmin(wr, bounds.right), r2_r_bound); + double i1_span = i1.right - i1.left; + double i2_span = i2.right - i2.left; + // double s = (stat / i1_span) * (stat / i2_span) + // * (-.5 * (r1_l - r1_r) * (2. * i1.right - 2. * i2.left + r1_l + r1_r) + // + (r2_r - r2_l) * r2_len + // + .5 * (r3_l - r3_r) * (2 * i1.left - 2. * i2.right + r3_l + + // r3_r)); + + double s + = (stat / i1_span) * (stat / i2_span) + * (-1. / 2 * (r1_l - r1_r) * (2 * i1.right - 2 * i2.left + r1_l + r1_r) + + (r2_r - r2_l) * r2_len + + 1. 
/ 2 * (r3_l - r3_r) * (2 * i1.left - 2 * i2.right + r3_l + r3_r)); + // printf("%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\n", s, r2_len, r2_l_bound, + // r2_r_bound, r1_l, r1_r, r2_l, r2_r, r3_l, r3_r, i1_span, i2_span); + // printf("%f\n", s); + return s; +} + +static int +tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + sample_count_stat_params_t *f_params, const double *bins, tsk_size_t num_bins, + double *result) +{ + int ret = 0; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + interval_t bounds, ivl_l, ivl_r; + iter_state l_state, r_state; + double *result_tmp = NULL, *result_row; + tsk_bitset_t node_samples, sample_sets_bits; + tsk_size_t i, j, k, bin_l, bin_r, *bincount = NULL, *bincount_row; + + tsk_memset(&sample_sets_bits, 0, sizeof(sample_sets_bits)); + tsk_memset(&node_samples, 0, sizeof(node_samples)); + tsk_memset(&l_state, 0, sizeof(l_state)); + tsk_memset(&r_state, 0, sizeof(r_state)); + result_tmp = tsk_malloc(result_dim * sizeof(*result_tmp)); + bincount = tsk_calloc(result_dim * (num_bins - 1), sizeof(*bincount)); + if (result_tmp == NULL || bincount == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = iter_state_init(&l_state, self, state_dim); + if (ret != 0) { + goto out; + } + ret = iter_state_init(&r_state, self, state_dim); + if (ret != 0) { + goto out; + } + ret = sample_sets_to_bitset( + self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits); + if (ret != 0) { + goto out; + } + ret = get_node_samples(self, state_dim, &sample_sets_bits, &node_samples); + if (ret != 0) { + goto out; + } + iter_state_clear(&l_state, state_dim, num_nodes, &node_samples); + // TODO: bin skipping based on range + for (i = 0; i < self->num_trees; i++) { + tsk_memset(result_tmp, 0, result_dim * sizeof(*result_tmp)); + 
iter_state_clear(&r_state, state_dim, num_nodes, &node_samples); + // TODO: check distance and continue if too short + // TODO: verify my assumptions that this is merely a setup step and we can + // discard the stat value + ret = advance_collect_edges(&l_state, (tsk_id_t) i); + if (ret != 0) { + goto out; + } + ivl_l = l_state.tree.tree_pos.interval; + ret = compute_two_tree_branch_stat( + self, &r_state, &l_state, f, f_params, result_dim, state_dim, result_tmp); + if (ret != 0) { + goto out; + } + for (j = i; j < self->num_trees; j++) { + ret = advance_collect_edges(&r_state, (tsk_id_t) j); + if (ret != 0) { + goto out; + } + ivl_r = r_state.tree.tree_pos.interval; + get_distance_bounds(ivl_l, ivl_r, &bounds); + bin_l = tsk_search_sorted(bins + 1, num_bins - 1, bounds.left); + bin_r = tsk_search_sorted(bins + 1, num_bins - 1, bounds.right); + ret = compute_two_tree_branch_stat(self, &l_state, &r_state, f, f_params, + result_dim, state_dim, result_tmp); + if (ret != 0) { + goto out; + } + do { + result_row = GET_2D_ROW(result, result_dim, bin_l); + bincount_row = GET_2D_ROW(bincount, result_dim, bin_l); + for (k = 0; k < result_dim; k++) { + result_row[k] += integrate_stat_over_window(ivl_l, ivl_r, bounds, + bins[bin_l], bins[bin_l + 1], result_tmp[k]); + bincount_row[k] += 1; + } + if (bin_l == bin_r) { + printf("%lu\t%lu\t%.14f\t%f\t%f\t== EQ ==\n", i, j, result_tmp[0], + bins[bin_l], bins[bin_l + 1]); + } else { + printf("%lu\t%lu\t%.14f\t%f\t%f\n", i, j, result_tmp[0], bins[bin_l], + bins[bin_l + 1]); + } + bin_l++; + } while (bin_l < bin_r); + } + } + for (i = 0; i < (num_bins - 1) * result_dim; i++) { + result[i] /= bincount[i]; + } + // printf("bins = { "); + // for (i = 0; i < num_bins - 1; i++) { + // printf("%f, ", bins[i]); + // } + // printf("%f }\n", bins[i]); + // printf("bincount = { "); + // for (i = 0; i < num_bins - 2; i++) { + // printf("%lu, ", bincount[i]); + // } + // printf("%lu }\n", bincount[i]); +out: + tsk_safe_free(result_tmp); + 
tsk_safe_free(bincount); + iter_state_free(&l_state); + iter_state_free(&r_state); + tsk_bitset_free(&node_samples); + tsk_bitset_free(&sample_sets_bits); + return ret; +} + +static int +tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + sample_count_stat_params_t *f_params, norm_func_t *norm_f, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + int ret = 0; + tsk_bitset_t allele_samples, allele_sample_sets = { 0 }; + bool polarised = options & TSK_STAT_POLARISED; + tsk_id_t *sites; + tsk_size_t i, j, k, bin, n_sites, *bincount_row; + double dist, *result_row, *result_tmp = NULL; + const tsk_size_t num_samples = self->num_samples; + const double *restrict site_position = self->tables->sites.position; + tsk_size_t *bincount = NULL, *num_alleles = NULL, *site_offsets = NULL, + *allele_counts = NULL; + tsk_size_t max_ss_size = 0, max_alleles = 0, n_alleles = 0; + two_locus_work_t work = { 0 }; + + tsk_memset(&allele_samples, 0, sizeof(allele_samples)); + n_sites = self->tables->sites.num_rows; + sites = tsk_malloc(n_sites * sizeof(*sites)); + if (sites == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + for (i = 0; i < n_sites; i++) { + sites[i] = (tsk_id_t) i; + } + // depends on n_sites + num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles)); + site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets)); + result_tmp = tsk_calloc(result_dim, sizeof(*result_tmp)); + bincount = tsk_calloc(num_bins * result_dim, sizeof(*bincount)); + if (num_alleles == NULL || site_offsets == NULL || result_tmp == NULL + || bincount == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + for (i = 0; i < n_sites; i++) { + site_offsets[i] = n_alleles * num_sample_sets; + n_alleles += self->site_mutations_length[sites[i]] + 1; + max_alleles = 
TSK_MAX(self->site_mutations_length[sites[i]], max_alleles); + } + max_alleles++; // add 1 for the ancestral allele + // depends on n_alleles + ret = tsk_bitset_init(&allele_samples, num_samples, n_alleles); + if (ret != 0) { + goto out; + } + for (i = 0; i < num_sample_sets; i++) { + max_ss_size = TSK_MAX(sample_set_sizes[i], max_ss_size); + } + // depend on n_alleles and max_ss_size + ret = tsk_bitset_init(&allele_sample_sets, max_ss_size, n_alleles * num_sample_sets); + if (ret != 0) { + goto out; + } + allele_counts = tsk_calloc(n_alleles * num_sample_sets, sizeof(*allele_counts)); + if (allele_counts == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + // depends on max_ss_size and max_alleles + ret = two_locus_work_init(max_alleles, max_ss_size, result_dim, state_dim, &work); + if (ret != 0) { + goto out; + } + // we track the number of alleles to account for backmutations + ret = get_mutation_samples(self, sites, n_sites, num_alleles, &allele_samples); + if (ret != 0) { + goto out; + } + get_mutation_sample_sets(&allele_samples, num_sample_sets, sample_set_sizes, + sample_sets, self->sample_index_map, &allele_sample_sets, allele_counts); + for (i = 0; i < n_sites; i++) { + for (j = i + 1; j < n_sites; j++) { + dist = site_position[j] - site_position[i]; + if (dist > bins[num_bins - 1]) { + break; + } + // if (bins[0] <= dist) { + // continue; + // } + bin = tsk_search_sorted(bins + 1, num_bins - 1, dist); + result_row = GET_2D_ROW(result, result_dim, bin); + bincount_row = GET_2D_ROW(bincount, result_dim, bin); + if (num_alleles[i] == 2 && num_alleles[j] == 2) { + // both sites are biallelic + ret = compute_general_two_site_stat_result(&allele_sample_sets, + allele_counts, site_offsets[i], site_offsets[j], state_dim, + result_dim, f, f_params, &work, result_tmp); + } else { + // at least one site is multiallelic + ret = compute_general_normed_two_site_stat_result(&allele_sample_sets, + allele_counts, site_offsets[i], 
site_offsets[j], num_alleles[i], + num_alleles[j], state_dim, result_dim, f, f_params, norm_f, + polarised, &work, result_tmp); + } + if (ret != 0) { + goto out; + } + for (k = 0; k < result_dim; k++) { + if (tsk_isnan(result_tmp[k])) { + continue; + } + result_row[k] += result_tmp[k]; + bincount_row[k] += 1; + } + tsk_memset(result_tmp, 0, sizeof(*result_tmp) * result_dim); + } + } + for (i = 0; i < (num_bins - 1) * result_dim; i++) { + result[i] /= bincount[i]; + } + // for (i = 0; i < num_bins - 1; i++) { + // result_row = GET_2D_ROW(result, result_dim, i); + // bincount_row = GET_2D_ROW(bincount, result_dim, i); + // for (k = 0; k < result_dim; k++) { + // result_row[k] /= (double) bincount_row[k]; + // } + // } +out: + tsk_safe_free(sites); + tsk_safe_free(bincount); + tsk_safe_free(result_tmp); + tsk_safe_free(num_alleles); + tsk_safe_free(site_offsets); + tsk_safe_free(allele_counts); + two_locus_work_free(&work); + tsk_bitset_free(&allele_samples); + tsk_bitset_free(&allele_sample_sets); + return ret; +} + +static int +tsk_treeseq_two_locus_decay_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, + norm_func_t *norm_f, const double *bins, tsk_size_t num_bins, tsk_flags_t options, + double *result) +{ + int ret = 0; + bool stat_site = !!(options & TSK_STAT_SITE); + bool stat_branch = !!(options & TSK_STAT_BRANCH); + tsk_size_t state_dim = num_sample_sets; + sample_count_stat_params_t f_params = { .sample_sets = sample_sets, + .num_sample_sets = num_sample_sets, + .sample_set_sizes = sample_set_sizes, + .set_indexes = set_indexes }; + + // We do not support two-locus node stats + if (!!(options & TSK_STAT_NODE)) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_STAT_MODE); + goto out; + } + // If no mode is specified, we default to site mode + if (!(stat_site || stat_branch)) { + stat_site = true; + } + // It's 
an error to specify more than one mode + if (stat_site + stat_branch > 1) { + ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES); + goto out; + } + ret = tsk_treeseq_check_sample_sets( + self, num_sample_sets, sample_set_sizes, sample_sets); + if (ret != 0) { + goto out; + } + if (result_dim < 1) { + ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS); + goto out; + } + if (stat_site) { + ret = tsk_treeseq_two_locus_site_decay_stat(self, state_dim, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, bins, + num_bins, options, result); + } else if (stat_branch) { + ret = tsk_treeseq_two_locus_branch_decay_stat(self, state_dim, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, &f_params, bins, num_bins, + result); + goto out; + } else { + ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + goto out; + } +out: + return ret; +} + /*********************************** * Allele frequency spectrum ***********************************/ @@ -4303,13 +4666,24 @@ tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { - options |= TSK_STAT_POLARISED; // TODO: allow user to pick? + options |= TSK_STAT_POLARISED; return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_summary_func, norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, result); } +int +tsk_treeseq_D_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + options |= TSK_STAT_POLARISED; // TODO: allow user to pick? 
+ return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D_summary_func, norm_total_weighted, bins, + num_bins, options, result); +} + static int D2_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4348,6 +4722,16 @@ tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, result); } +int +tsk_treeseq_D2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D2_summary_func, norm_total_weighted, bins, + num_bins, options, result); +} + static int r2_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4387,6 +4771,16 @@ tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, row_sites, row_positions, num_cols, col_sites, col_positions, options, result); } +int +tsk_treeseq_r2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, r2_summary_func, norm_hap_weighted, bins, + num_bins, options, result); +} + static int D_prime_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4424,13 +4818,24 @@ tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { - options |= TSK_STAT_POLARISED; // TODO: allow user to 
pick? + options |= TSK_STAT_POLARISED; return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_prime_summary_func, norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, result); } +int +tsk_treeseq_D_prime_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + options |= TSK_STAT_POLARISED; + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D_prime_summary_func, norm_total_weighted, + bins, num_bins, options, result); +} + static int r_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4465,13 +4870,24 @@ tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { - options |= TSK_STAT_POLARISED; // TODO: allow user to pick? 
+ options |= TSK_STAT_POLARISED; return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, r_summary_func, norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, result); } +int +tsk_treeseq_r_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + options |= TSK_STAT_POLARISED; + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, r_summary_func, norm_total_weighted, bins, + num_bins, options, result); +} + static int Dz_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4511,6 +4927,16 @@ tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, result); } +int +tsk_treeseq_Dz_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, Dz_summary_func, norm_total_weighted, bins, + num_bins, options, result); +} + static int pi2_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4547,6 +4973,16 @@ tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, result); } +int +tsk_treeseq_pi2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, 
num_sample_sets, NULL, pi2_summary_func, norm_total_weighted, bins, + num_bins, options, result); +} + static int D2_unbiased_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4584,6 +5020,16 @@ tsk_treeseq_D2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, col_positions, options, result); } +int +tsk_treeseq_D2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D2_unbiased_summary_func, + norm_total_weighted, bins, num_bins, options, result); +} + static int Dz_unbiased_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4622,6 +5068,16 @@ tsk_treeseq_Dz_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, col_positions, options, result); } +int +tsk_treeseq_Dz_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, Dz_unbiased_summary_func, + norm_total_weighted, bins, num_bins, options, result); +} + static int pi2_unbiased_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4660,6 +5116,16 @@ tsk_treeseq_pi2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, col_positions, options, result); } +int +tsk_treeseq_pi2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t 
*sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, pi2_unbiased_summary_func, + norm_total_weighted, bins, num_bins, options, result); +} + /*********************************** * Two way stats ***********************************/ diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 21495edbf7..6185d69507 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -118,12 +118,14 @@ typedef struct { tsk_table_collection_t *tables; } tsk_treeseq_t; +typedef struct { + double left; + double right; +} interval_t; + typedef struct { tsk_id_t index; - struct { - double left; - double right; - } interval; + interval_t interval; struct { tsk_id_t start; tsk_id_t stop; @@ -1135,6 +1137,42 @@ typedef int k_way_two_locus_count_stat_method(const tsk_treeseq_t *self, const double *row_positions, tsk_size_t num_cols, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result); +typedef int two_locus_decay_stat_method(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, const double *bins, tsk_size_t num_bins, + tsk_flags_t options, double *result); + +int tsk_treeseq_D_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); +int tsk_treeseq_D2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); +int tsk_treeseq_r2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double 
*result); +int tsk_treeseq_D_prime_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); +int tsk_treeseq_r_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); +int tsk_treeseq_Dz_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); +int tsk_treeseq_pi2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); +int tsk_treeseq_D2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); +int tsk_treeseq_Dz_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); +int tsk_treeseq_pi2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); + /* Two way sample set stats */ int tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 78cb9f7c8e..c54ad80df7 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -990,6 +990,12 @@ bool_array_converter(PyObject *py_obj, PyArrayObject **array_out) 
return array_converter(NPY_BOOL, py_obj, array_out); } +static int +float64_array_converter(PyObject *py_obj, PyArrayObject **array_out) +{ + return array_converter(NPY_FLOAT64, py_obj, array_out); +} + /* Note: it doesn't seem to be possible to cast pointers to the actual * table functions to this type because the first argument must be a * void *, so the simplest option is to put in a small shim that @@ -8105,6 +8111,127 @@ TreeSequence_r2_ij_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) return TreeSequence_k_way_ld_matrix(self, args, kwds, 2, tsk_treeseq_r2_ij); } +static PyObject * +TreeSequence_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, + two_locus_decay_stat_method *method) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "bins", "mode", NULL }; + + PyObject *sample_sets = NULL, *sample_set_sizes = NULL; + PyArrayObject *sample_sets_array = NULL, *sample_set_sizes_array = NULL, + *bins = NULL, *result_matrix = NULL; + npy_intp num_bins, result_dim[2]; + char *mode = NULL; + tsk_size_t num_sample_sets; + tsk_flags_t options = 0; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO&|s", kwlist, &sample_set_sizes, + &sample_sets, &float64_array_converter, &bins, &mode)) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; + } + num_bins = PyArray_DIM(bins, 0); + result_dim[0] = num_bins - 1; + result_dim[1] = num_sample_sets; + result_matrix = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_FLOAT64, 0); + if (result_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + // clang-format off + Py_BEGIN_ALLOW_THREADS + err = method(self->tree_sequence, num_sample_sets, + PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), + 
PyArray_DATA(bins), num_bins, options, PyArray_DATA(result_matrix)); + Py_END_ALLOW_THREADS + // clang-format on + if (err != 0) + { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_matrix; + result_matrix = NULL; +out: + Py_XDECREF(bins); + Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(result_matrix); + return ret; +} + +static PyObject * +TreeSequence_D_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_D_decay); +} + +static PyObject * +TreeSequence_D2_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_D2_decay); +} + +static PyObject * +TreeSequence_r2_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_r2_decay); +} + +static PyObject * +TreeSequence_D_prime_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_D_prime_decay); +} + +static PyObject * +TreeSequence_r_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_r_decay); +} + +static PyObject * +TreeSequence_Dz_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_Dz_decay); +} + +static PyObject * +TreeSequence_pi2_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_pi2_decay); +} + +static PyObject * +TreeSequence_pi2_unbiased_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_pi2_unbiased_decay); +} + +static PyObject * +TreeSequence_D2_unbiased_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_D2_unbiased_decay); +} + 
+static PyObject * +TreeSequence_Dz_unbiased_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_Dz_unbiased_decay); +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -8800,6 +8927,46 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_pi2_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the pi2 matrix." }, + { .ml_name = "D_decay", + .ml_meth = (PyCFunction) TreeSequence_D_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the D decay curve." }, + { .ml_name = "D2_decay", + .ml_meth = (PyCFunction) TreeSequence_D2_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the D2 decay curve." }, + { .ml_name = "r2_decay", + .ml_meth = (PyCFunction) TreeSequence_r2_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the r2 decay curve." }, + { .ml_name = "D_prime_decay", + .ml_meth = (PyCFunction) TreeSequence_D_prime_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the D_prime decay curve." }, + { .ml_name = "r_decay", + .ml_meth = (PyCFunction) TreeSequence_r_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the r decay curve." }, + { .ml_name = "Dz_decay", + .ml_meth = (PyCFunction) TreeSequence_Dz_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the Dz decay curve." }, + { .ml_name = "pi2_decay", + .ml_meth = (PyCFunction) TreeSequence_pi2_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the pi2 decay curve." }, + { .ml_name = "D2_unbiased_decay", + .ml_meth = (PyCFunction) TreeSequence_D2_unbiased_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased D2 decay curve." 
}, + { .ml_name = "Dz_unbiased_decay", + .ml_meth = (PyCFunction) TreeSequence_Dz_unbiased_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased Dz decay curve." }, + { .ml_name = "pi2_unbiased_decay", + .ml_meth = (PyCFunction) TreeSequence_pi2_unbiased_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased pi2 decay curve." }, { .ml_name = "D2_unbiased_matrix", .ml_meth = (PyCFunction) TreeSequence_D2_unbiased_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 9904b60a98..58fbe32092 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -23,6 +23,7 @@ """ Module responsible for managing trees and tree sequences. """ + from __future__ import annotations import base64 @@ -696,8 +697,7 @@ def __init__( options = 0 if sample_counts is not None: warnings.warn( - "The sample_counts option is not supported since 0.2.4 " - "and is ignored", + "The sample_counts option is not supported since 0.2.4 and is ignored", RuntimeWarning, stacklevel=4, ) @@ -6889,7 +6889,7 @@ def to_macs(self): bytes_genotypes[:] = lookup[variant.genotypes] genotypes = bytes_genotypes.tobytes().decode() output.append( - f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t" f"{genotypes}" + f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t{genotypes}" ) return "\n".join(output) + "\n" @@ -8331,6 +8331,49 @@ def __k_way_sample_set_stat( stat = stat[()] return stat + def _try_drop_dimension(self, sample_sets): + # First try to convert to a 1D numpy array. If we succeed, then we strip off + # the corresponding dimension from the output. 
+ drop_dimension = False + try: + sample_sets = np.array(sample_sets, dtype=np.uint64) + except ValueError: + pass + else: + # If we've successfully converted sample_sets to a 1D numpy array + # of integers then drop the dimension + if len(sample_sets.shape) == 1: + sample_sets = [sample_sets] + drop_dimension = True + return sample_sets, drop_dimension + + def __two_locus_sample_set_decay_stat( + self, + ll_method, + sample_sets, + bins, + mode=None, + ): + if sample_sets is None: + sample_sets = self.samples() + + sample_sets, drop_dimension = self._try_drop_dimension(sample_sets) + sample_set_sizes = np.array( + [len(sample_set) for sample_set in sample_sets], dtype=np.uint32 + ) + if np.any(sample_set_sizes == 0): + raise ValueError("Sample sets must contain at least one element") + + flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + result = ll_method(sample_set_sizes, flattened, bins, mode) + if drop_dimension: + result = result.reshape(result.shape[0]) + else: + # Orient the data so that the first dimension is the sample set. + result = result.swapaxes(0, 1) + + return result + def __k_way_weighted_stat( self, ll_method, @@ -9281,9 +9324,9 @@ def pca( if time_windows is None: tree_sequence_low, tree_sequence_high = None, self else: - assert ( - time_windows[0] < time_windows[1] - ), "The second argument should be larger." + assert time_windows[0] < time_windows[1], ( + "The second argument should be larger." 
+ ) tree_sequence_low, tree_sequence_high = ( self.decapitate(time_windows[0]), self.decapitate(time_windows[1]), @@ -9351,9 +9394,9 @@ def _rand_pow_range_finder( """ Algorithm 9 in https://arxiv.org/pdf/2002.01387 """ - assert ( - num_vectors >= rank > 0 - ), "num_vectors should not be smaller than rank" + assert num_vectors >= rank > 0, ( + "num_vectors should not be smaller than rank" + ) for _ in range(depth): Q = np.linalg.qr(Q)[0] Q = operator(Q) @@ -10880,6 +10923,32 @@ def ld_matrix( stat_func, sample_sets, sites=sites, positions=positions, mode=mode ) + def ld_decay(self, bins, sample_sets=None, mode="site", stat="r2"): + stats = { + "D": self._ll_tree_sequence.D_decay, + "D2": self._ll_tree_sequence.D2_decay, + "r2": self._ll_tree_sequence.r2_decay, + "D_prime": self._ll_tree_sequence.D_prime_decay, + "r": self._ll_tree_sequence.r_decay, + "Dz": self._ll_tree_sequence.Dz_decay, + "pi2": self._ll_tree_sequence.pi2_decay, + "Dz_unbiased": self._ll_tree_sequence.Dz_unbiased_decay, + "D2_unbiased": self._ll_tree_sequence.D2_unbiased_decay, + "pi2_unbiased": self._ll_tree_sequence.pi2_unbiased_decay, + } + try: + stat_func = stats[stat] + except KeyError: + raise ValueError( + f"Unknown two-locus statistic '{stat}', we support: {list(stats.keys())}" + ) + return self.__two_locus_sample_set_decay_stat( + stat_func, + sample_sets, + bins=bins, + mode=mode, + ) + def sample_nodes_by_ploidy(self, ploidy): """ Returns an 2D array of node IDs, where each row has length `ploidy`. From 594386aa9382dd77c5175efca5f55fc6a1678447 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 7 Nov 2025 00:00:07 -0600 Subject: [PATCH 3/7] finalize algo, we get all of the LD density now. 
--- c/tskit/trees.c | 101 ++++--------- python/tests/test_ld_decay.py | 261 ++++++++++++++++++++++++++++++++++ 2 files changed, 292 insertions(+), 70 deletions(-) create mode 100644 python/tests/test_ld_decay.py diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 675d6244b1..2e0e8c83e9 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3461,59 +3461,33 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl return ret; } -/* Find the distance span between two trees (ie the greatest and smallest distance that - * can be represented by the subtraction of two tree's intervals. - * NB: required condition: ivl_1 <= ivl_2 */ -static inline void -get_distance_bounds(const interval_t i1, const interval_t i2, interval_t *out) -{ - // Equal intervals will stretch into the negative. - // TODO: any others? If not, the following will suffice. - // // If the intervals are equal - // if (i1.left == i2.left && i1.right == i2.right) { - // out->left = 0; - // out->right = i1.right; - // return; - // } - out->left = fmax(0, i2.left - i1.right); - out->right = i2.right - i1.left; -} - // TODO: breaks at 0,1/0,1 -- span is into negative, no need to 1/2 stat. +// TODO: support is bound or as written in test? static double -integrate_stat_over_window(const interval_t i1, const interval_t i2, - const interval_t bounds, double wl, double wr, double stat) +integrate_stat_over_bin( + const interval_t i1, const interval_t i2, double bl, double br, double stat) { + interval_t support = { i2.left - i1.right, i2.right - i1.left }; double r2_len = fmin(i1.right - i1.left, i2.right - i2.left); // Size of the center region is determined by the larger of the two // intervals. It is zero if they are equal (triangle). 
double r2_l_bound = fmin(i2.left - i1.left, i2.right - i1.right); - double r2_r_bound = bounds.right - r2_len; + double r2_r_bound = support.right - r2_len; // left and right values for each of the 3 regions to integrate over // variable names are: r{region}_{left|right} - double r1_l = fmin(fmax(wl, bounds.left), r2_l_bound); - double r1_r = fmax(fmin(wr, r2_l_bound), bounds.left); - double r2_l = fmin(fmax(wl, r2_l_bound), r2_r_bound); - double r2_r = fmax(fmin(wr, r2_r_bound), r2_l_bound); - double r3_l = fmin(fmax(wl, r2_r_bound), bounds.right); - double r3_r = fmax(fmin(wr, bounds.right), r2_r_bound); + double r1_l = fmin(fmax(bl, support.left), r2_l_bound); + double r1_r = fmax(fmin(br, r2_l_bound), support.left); + double r2_l = fmin(fmax(bl, r2_l_bound), r2_r_bound); + double r2_r = fmax(fmin(br, r2_r_bound), r2_l_bound); + double r3_l = fmin(fmax(bl, r2_r_bound), support.right); + double r3_r = fmax(fmin(br, support.right), r2_r_bound); double i1_span = i1.right - i1.left; double i2_span = i2.right - i2.left; - // double s = (stat / i1_span) * (stat / i2_span) - // * (-.5 * (r1_l - r1_r) * (2. * i1.right - 2. * i2.left + r1_l + r1_r) - // + (r2_r - r2_l) * r2_len - // + .5 * (r3_l - r3_r) * (2 * i1.left - 2. * i2.right + r3_l + - // r3_r)); - - double s - = (stat / i1_span) * (stat / i2_span) - * (-1. / 2 * (r1_l - r1_r) * (2 * i1.right - 2 * i2.left + r1_l + r1_r) - + (r2_r - r2_l) * r2_len - + 1. / 2 * (r3_l - r3_r) * (2 * i1.left - 2 * i2.right + r3_l + r3_r)); - // printf("%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\t%f\n", s, r2_len, r2_l_bound, - // r2_r_bound, r1_l, r1_r, r2_l, r2_r, r3_l, r3_r, i1_span, i2_span); - // printf("%f\n", s); - return s; + + return stat / (i1_span * i2_span) + * (-1. / 2 * (r1_l - r1_r) * (2 * i1.right - 2 * i2.left + r1_l + r1_r) + + (r2_r - r2_l) * r2_len + + 1. 
/ 2 * (r3_l - r3_r) * (2 * i1.left - 2 * i2.right + r3_l + r3_r)); } static int @@ -3559,13 +3533,9 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st goto out; } iter_state_clear(&l_state, state_dim, num_nodes, &node_samples); - // TODO: bin skipping based on range for (i = 0; i < self->num_trees; i++) { tsk_memset(result_tmp, 0, result_dim * sizeof(*result_tmp)); iter_state_clear(&r_state, state_dim, num_nodes, &node_samples); - // TODO: check distance and continue if too short - // TODO: verify my assumptions that this is merely a setup step and we can - // discard the stat value ret = advance_collect_edges(&l_state, (tsk_id_t) i); if (ret != 0) { goto out; @@ -3582,7 +3552,11 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st goto out; } ivl_r = r_state.tree.tree_pos.interval; - get_distance_bounds(ivl_l, ivl_r, &bounds); + bounds = (interval_t){ fmax(0, ivl_r.left - ivl_l.right), + fmin(bins[num_bins - 1], ivl_r.right - ivl_l.left) }; + if (bounds.left > bins[num_bins - 1] || bounds.right < bins[0]) { + continue; + } bin_l = tsk_search_sorted(bins + 1, num_bins - 1, bounds.left); bin_r = tsk_search_sorted(bins + 1, num_bins - 1, bounds.right); ret = compute_two_tree_branch_stat(self, &l_state, &r_state, f, f_params, @@ -3594,29 +3568,23 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st result_row = GET_2D_ROW(result, result_dim, bin_l); bincount_row = GET_2D_ROW(bincount, result_dim, bin_l); for (k = 0; k < result_dim; k++) { - result_row[k] += integrate_stat_over_window(ivl_l, ivl_r, bounds, - bins[bin_l], bins[bin_l + 1], result_tmp[k]); + // double s = integrate_stat_over_window(ivl_l, ivl_r, bounds, + // bins[bin_l], bins[bin_l + 1], result_tmp[k]); + // printf("%lu\t%lu\t%.15f\t%.15f\t%.15f\t%.15f\t%.15f\t%.15f\n", i, + // j, + // bounds.left, bounds.right, bins[bin_l], bins[bin_l + 1], + // result_tmp[k], s); + result_row[k] += integrate_stat_over_bin( + 
ivl_l, ivl_r, bins[bin_l], bins[bin_l + 1], result_tmp[k]); bincount_row[k] += 1; } - if (bin_l == bin_r) { - printf("%lu\t%lu\t%.14f\t%f\t%f\t== EQ ==\n", i, j, result_tmp[0], - bins[bin_l], bins[bin_l + 1]); - } else { - printf("%lu\t%lu\t%.14f\t%f\t%f\n", i, j, result_tmp[0], bins[bin_l], - bins[bin_l + 1]); - } bin_l++; - } while (bin_l < bin_r); + } while (bin_l <= bin_r); } } for (i = 0; i < (num_bins - 1) * result_dim; i++) { result[i] /= bincount[i]; } - // printf("bins = { "); - // for (i = 0; i < num_bins - 1; i++) { - // printf("%f, ", bins[i]); - // } - // printf("%f }\n", bins[i]); // printf("bincount = { "); // for (i = 0; i < num_bins - 2; i++) { // printf("%lu, ", bincount[i]); @@ -3724,7 +3692,7 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat // both sites are biallelic ret = compute_general_two_site_stat_result(&allele_sample_sets, allele_counts, site_offsets[i], site_offsets[j], state_dim, - result_dim, f, f_params, &work, &(result_row[j * result_dim])); + result_dim, f, f_params, &work, result_tmp); } else { // at least one site is multiallelic ret = compute_general_normed_two_site_stat_result(&allele_sample_sets, @@ -3748,13 +3716,6 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat for (i = 0; i < (num_bins - 1) * result_dim; i++) { result[i] /= bincount[i]; } - // for (i = 0; i < num_bins - 1; i++) { - // result_row = GET_2D_ROW(result, result_dim, i); - // bincount_row = GET_2D_ROW(bincount, result_dim, i); - // for (k = 0; k < result_dim; k++) { - // result_row[k] /= (double) bincount_row[k]; - // } - // } out: tsk_safe_free(sites); tsk_safe_free(bincount); diff --git a/python/tests/test_ld_decay.py b/python/tests/test_ld_decay.py new file mode 100644 index 0000000000..4dd1ba92e1 --- /dev/null +++ b/python/tests/test_ld_decay.py @@ -0,0 +1,261 @@ +import contextlib +from itertools import combinations +from itertools import combinations_with_replacement +from itertools 
import product + +import demes +import msprime +import numpy as np +import pytest + +from tskit import Interval + + +@contextlib.contextmanager +def suppress_overflow_div0_warning(): + with np.errstate(over="ignore", invalid="ignore", divide="ignore"): + yield + + +def expand_dims(arr): + """ + Expand the dimensions of the provided array (arrays). This helps to control + the output dimensions of the ld matrix (ie if there's 2 dimensions to + indexes or sample_sets, we'll get a 3D ld matrix back. This will not be + necessary in the C implementation because dimension dropping happens in the + python layer. + """ + try: + arr = np.asarray(arr) + if arr.ndim == 1: + return np.expand_dims(arr, axis=0) + except: + pass + try: + arr = [np.asarray(a) for a in arr] + except Exception as e: + raise ValueError("Must be a list of 1D array-like") from e + for a in arr: + if a.ndim != 1: + raise ValueError("Must be a list of 1D arrays") + return arr + + +def check_bins(bins, seq_len): + try: + bins = np.asarray(bins) + except Exception as e: + raise ValueError("Bins must be coercible to a 1D array") from e + if bins.ndim != 1: + raise ValueError("Bins must be a 1D array") + if not np.all(bins[:-1] <= bins[1:]): + raise ValueError("Bins must be sorted") + if bins[-1] > seq_len: + raise ValueError(f"Last bin is out of bounds, must be <= L: {bins[-1]}") + if len(bins) < 2: + raise ValueError(f"Must have at least 2 bins, got {len(bins)}") + if (bins < 0).any(): + raise ValueError("Bins must be greater than 0") + return bins + + +def construct_ld_matrix(ts, stat, sample_sets, indexes): + """ + Produce an ld matrix with the same error characteristics as the C version + Create an LD matrix by starting at the diagonal of each row. This ensures + that we accumulate error in the same way as we would in the C version. If + we produce an LD matrix starting from tree 0 at each row, we accumulate a + different (likely more) amount of error. 
+ """ + bp = ts.breakpoints(as_array=True)[:-1] + # TODO: output dims + out = np.zeros((1, ts.num_trees, ts.num_trees)) + for i, b in enumerate(bp): + out[0, i, i:] = ts.ld_matrix(mode="branch", stat=stat, positions=[[b], bp[i:]]) + return out + + +def integrate_stat_over_bin(bin, i1, i2, stat): + bl, br = bin + # Integration support + l_support = i2.left - i1.right + r_support = i2.right - i1.left + # length of the middle region + r2_len = min(i1.right - i1.left, i2.right - i2.left) + # bounds of the middle region + r2_l_bound = min(i2.left - i1.left, i2.right - i1.right) + r2_r_bound = r_support - r2_len + + r1_l = min(max(bl, l_support), r2_l_bound) + r1_r = max(min(br, r2_l_bound), l_support) + r2_l = min(max(bl, r2_l_bound), r2_r_bound) + r2_r = max(min(br, r2_r_bound), r2_l_bound) + r3_l = min(max(bl, r2_r_bound), r_support) + r3_r = max(min(br, r_support), r2_r_bound) # this one differs from mm nb + return ( + stat + / (i1.span * i2.span) + * ( + -1 / 2 * (r1_l - r1_r) * (2 * i1.right - 2 * i2.left + r1_l + r1_r) + + (r2_r - r2_l) * r2_len + + 1 / 2 * (r3_l - r3_r) * (2 * i1.left - 2 * i2.right + r3_l + r3_r) + ) + ) + + +def isect(l1, r1, l2, r2): + """Right closed left open""" + return max(l1, l2) < min(r1, r2) or l1 == r2 or l2 == r1 + + +def get_tree_pair_bounds(ivl_l, ivl_r, bins): + return Interval( + max(0, ivl_r.left - ivl_l.right), + min(bins[-1], ivl_r.right - ivl_l.left), + ) + + +def ld_decay_branch(ts, bins, stat, sample_sets, indexes): + ld = construct_ld_matrix(ts, stat, sample_sets, indexes) + dims = (len(indexes or sample_sets), len(bins) - 1) + result = np.zeros(dims, dtype=float) + bincount = np.zeros(dims, dtype=int) + bp = ts.breakpoints(as_array=True) + bin_ivls = np.fromiter(zip(bins[:-1], bins[1:]), np.dtype((float, 2))) + for i, j in combinations_with_replacement(range(ts.num_trees), 2): # upper tri+diag + ivl_l = Interval(bp[i], bp[i + 1]) + ivl_r = Interval(bp[j], bp[j + 1]) + bounds = get_tree_pair_bounds(ivl_l, ivl_r, bins) + for 
k in range(dims[0]): + # for b in bin_ivls: + # if isect(*bounds, *b): + # s = integrate_stat_over_bin(b, ivl_l, ivl_r, ld[k, i, j]) + # print( + # f"{i}\t{j}\t" + # f"{bounds.left:.15f}\t{bounds.right:.15f}\t" + # f"{b[0]:.15f}\t{b[1]:.15f}\t" + # f"{ld[k, i, j]:.15f}\t" + # f"{s:.15f}" + # ) + result[k] += np.apply_along_axis( + integrate_stat_over_bin, 1, bin_ivls, ivl_l, ivl_r, ld[k, i, j] + ) + bincount[k] += np.fromiter((isect(*bounds, *b) for b in bin_ivls), int) + if dims[0] == 1: # drop dims if first dim is length 1 + return result.reshape(dims[1:]), bincount.reshape(dims[1:]) + return result, bincount + + +def ld_decay_site(ts, bins, stat, sample_sets, indexes): + ld = ts.ld_matrix(stat=stat, sample_sets=sample_sets, indexes=indexes) + dims = (len(indexes or sample_sets), len(bins) - 1) + result = np.zeros(dims, dtype=float) + bincount = np.zeros(dims, dtype=int) + site_pos = ts.sites_position + for i, j in combinations(range(ts.num_sites), 2): # upper tri-diag + dist = site_pos[j] - site_pos[i] + if dist > bins[-1]: + break + bin = np.searchsorted(bins[1:], dist, side="left") + # if bin == 3 and np.isnan(ld[:, i, j]).any(): + # breakpoint() + for k in range(dims[0]): + s = ld[k, i, j] + if np.isnan(s): + continue + result[k, bin] += s + bincount[k, bin] += 1 + if dims[0] == 1: # drop dims if first dim is length 1 + return result.reshape(dims[1:]), bincount.reshape(dims[1:]) + return result, bincount + + +def ld_decay( + ts, + bins, + stat="r2", + sample_sets=None, + indexes=None, + mode="site", + return_counts=False, +): + bins = check_bins(bins, ts.sequence_length) + sample_sets = expand_dims(sample_sets or [ts.samples()]) + if indexes is not None: + indexes = expand_dims(indexes) + match mode: + case "site": + result, count = ld_decay_site(ts, bins, stat, sample_sets, indexes) + case "branch": + result, count = ld_decay_branch(ts, bins, stat, sample_sets, indexes) + case _: + raise ValueError(f"Unknown Stats Mode: {mode}") + if return_counts: + return 
result, count + with suppress_overflow_div0_warning(): + return result / count + + +ONE_WAY_STATS = [ + "r", + "r2", + "D", + "D2", + "D_prime", + "pi2", + "Dz", + "D2_unbiased", + "Dz_unbiased", + "pi2_unbiased", +] + +TS = msprime.sim_mutations( + msprime.sim_ancestry( + samples=100, + sequence_length=1e5, + recombination_rate=1e-8, + demography=msprime.Demography.from_demes( + demes.loads(""" + time_units: generations + demes: + - name: A + epochs: + - {start_size: 5000, end_time: 1000} + - {start_size: 1000, end_time: 400} + - {start_size: 5000, end_time: 0} + """) + ), + random_seed=23, + ), + rate=1e-7, + random_seed=23, +) + + +@pytest.mark.parametrize("stat,mode", product(ONE_WAY_STATS, ["site", "branch"])) +def test_ld_decay(stat, mode): + bins = np.logspace(0, np.log10(TS.sequence_length), num=35) + bins[0] = 0 + decay, counts = ld_decay(TS, bins, stat=stat, mode=mode, return_counts=True) + c = TS.ld_decay(bins, stat=stat, mode=mode) + with suppress_overflow_div0_warning(): + np.testing.assert_array_equal(decay / counts, c) + # Verify that the sum of all LD in our bins is equal to the sum of the LD + # matrix entries from which they originated. 
+ if mode == "branch": + tu = np.triu( + construct_ld_matrix( + TS, sample_sets=expand_dims(TS.samples()), indexes=None, stat=stat + ).squeeze() + ) + dmask = np.diag_indices_from(tu) + tu[dmask] = tu[dmask] / 2 # we take half the density on the diagonal + np.testing.assert_allclose(np.nansum(decay), np.nansum(tu)) + # all but r2 D2 Dz are within 1 ulp + np.testing.assert_array_almost_equal_nulp( + np.nansum(decay), np.nansum(tu), nulp=2 + ) + print(f"{stat} diff={np.nansum(decay) - np.nansum(tu)}") + elif mode == "site": + tu = TS.ld_matrix(stat=stat)[np.triu_indices(TS.num_sites, k=1)] + np.testing.assert_allclose(decay.sum(), np.nansum(tu)) From d53f45706aea60588738a18fb8db6208b178c922 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 18 Nov 2025 03:40:19 -0600 Subject: [PATCH 4/7] Fix issue with bins that do not span whole length Skip comparison if out of range, do not advance edges unless a statistic is to be computed. Also add two-way stats --- c/tskit/trees.c | 124 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 102 insertions(+), 22 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 2e0e8c83e9..dcd7305e5f 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3498,6 +3498,7 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st double *result) { int ret = 0; + const double *restrict breakpoints = self->breakpoints; const tsk_size_t num_nodes = self->tables->nodes.num_rows; interval_t bounds, ivl_l, ivl_r; iter_state l_state, r_state; @@ -3547,18 +3548,18 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st goto out; } for (j = i; j < self->num_trees; j++) { - ret = advance_collect_edges(&r_state, (tsk_id_t) j); - if (ret != 0) { - goto out; - } - ivl_r = r_state.tree.tree_pos.interval; + ivl_r = (interval_t){ breakpoints[j], breakpoints[j + 1] }; bounds = (interval_t){ fmax(0, ivl_r.left - ivl_l.right), fmin(bins[num_bins - 1], ivl_r.right - ivl_l.left) }; + bin_l = 
tsk_search_sorted(bins + 1, num_bins - 1, bounds.left); + bin_r = tsk_search_sorted(bins + 1, num_bins - 1, bounds.right); if (bounds.left > bins[num_bins - 1] || bounds.right < bins[0]) { continue; } - bin_l = tsk_search_sorted(bins + 1, num_bins - 1, bounds.left); - bin_r = tsk_search_sorted(bins + 1, num_bins - 1, bounds.right); + ret = advance_collect_edges(&r_state, (tsk_id_t) j); + if (ret != 0) { + goto out; + } ret = compute_two_tree_branch_stat(self, &l_state, &r_state, f, f_params, result_dim, state_dim, result_tmp); if (ret != 0) { @@ -3568,12 +3569,7 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st result_row = GET_2D_ROW(result, result_dim, bin_l); bincount_row = GET_2D_ROW(bincount, result_dim, bin_l); for (k = 0; k < result_dim; k++) { - // double s = integrate_stat_over_window(ivl_l, ivl_r, bounds, - // bins[bin_l], bins[bin_l + 1], result_tmp[k]); - // printf("%lu\t%lu\t%.15f\t%.15f\t%.15f\t%.15f\t%.15f\t%.15f\n", i, - // j, - // bounds.left, bounds.right, bins[bin_l], bins[bin_l + 1], - // result_tmp[k], s); + // TODO: nansum?? 
result_row[k] += integrate_stat_over_bin( ivl_l, ivl_r, bins[bin_l], bins[bin_l + 1], result_tmp[k]); bincount_row[k] += 1; @@ -3585,11 +3581,14 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st for (i = 0; i < (num_bins - 1) * result_dim; i++) { result[i] /= bincount[i]; } - // printf("bincount = { "); - // for (i = 0; i < num_bins - 2; i++) { - // printf("%lu, ", bincount[i]); + // for (k = 0; k < result_dim; k++) { + // bincount_row = GET_2D_ROW(bincount, num_bins - 1, k); + // printf("bincount%lu = { ", k); + // for (i = 0; i < num_bins - 2; i++) { + // printf("%lu, ", bincount_row[i]); + // } + // printf("%lu }\n", bincount_row[i]); // } - // printf("%lu }\n", bincount[i]); out: tsk_safe_free(result_tmp); tsk_safe_free(bincount); @@ -3633,8 +3632,8 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat // depends on n_sites num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles)); site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets)); - result_tmp = tsk_calloc(sizeof(*result_tmp), result_dim); - bincount = tsk_calloc(sizeof(*bincount), num_bins * result_dim); + result_tmp = tsk_malloc(result_dim * sizeof(*result_tmp)); + bincount = tsk_calloc(result_dim * (num_bins - 1), sizeof(*bincount)); if (num_alleles == NULL || site_offsets == NULL || result_tmp == NULL || bincount == NULL) { ret = tsk_trace_error(TSK_ERR_NO_MEMORY); @@ -3682,9 +3681,10 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat if (dist > bins[num_bins - 1]) { break; } - // if (bins[0] <= dist) { - // continue; - // } + // TODO: Very left interval is closed? 
+ if (dist < bins[0]) { + continue; + } bin = tsk_search_sorted(bins + 1, num_bins - 1, dist); result_row = GET_2D_ROW(result, result_dim, bin); bincount_row = GET_2D_ROW(bincount, result_dim, bin); @@ -3709,13 +3709,39 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat } result_row[k] += result_tmp[k]; bincount_row[k] += 1; + // if (result_row[0] != result_row[k]) { + // printf("BAD %lu\t%f\t%f\n", k, result_row[0], result_row[k]); + // } } tsk_memset(result_tmp, 0, sizeof(*result_tmp) * result_dim); } } + // puts("result\t\tbincount"); for (i = 0; i < (num_bins - 1) * result_dim; i++) { + // printf("%f\t%lu\n", result[i], bincount[i]); result[i] /= bincount[i]; } + // puts(""); + // puts("======="); + // puts("bincount = {"); + // for (i = 0; i < num_bins - 1; i++) { + // bincount_row = GET_2D_ROW(bincount, result_dim, i); + // printf(" { "); + // for (k = 0; k < result_dim - 1; k++) { + // printf("%lu, ", bincount_row[k]); + // } + // printf("%lu },\n", bincount_row[k]); + // } + // puts("result = {"); + // for (i = 0; i < num_bins - 1; i++) { + // result_row = GET_2D_ROW(result, result_dim, i); + // printf(" { "); + // for (k = 0; k < result_dim - 1; k++) { + // printf("%f, ", result_row[k]); + // } + // printf("%f },\n", result_row[k]); + // } + // puts("}"); out: tsk_safe_free(sites); tsk_safe_free(bincount); @@ -5463,6 +5489,24 @@ tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, return ret; } +int +tsk_treeseq_D2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + 
sample_sets, num_index_tuples, index_tuples, D2_ij_summary_func, + norm_total_weighted, bins, num_bins, options, result); +out: + return ret; +} + static int D2_ij_unbiased_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, tsk_size_t result_dim, double *result, void *params) @@ -5537,6 +5581,24 @@ tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets return ret; } +int +tsk_treeseq_D2_ij_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, D2_ij_unbiased_summary_func, + norm_total_weighted, bins, num_bins, options, result); +out: + return ret; +} + static int r2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, tsk_size_t result_dim, double *result, void *params) @@ -5597,6 +5659,24 @@ tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, return ret; } +int +tsk_treeseq_r2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, r2_ij_summary_func, + norm_hap_weighted_ij, bins, num_bins, options, result); +out: + return ret; 
+} + /*********************************** * Three way stats ***********************************/ From d78ec208cf07e1d19220b99d52e5f7adad7173ed Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 18 Nov 2025 03:42:05 -0600 Subject: [PATCH 5/7] add two-way ld decay; keep debugging code --- c/tskit/trees.h | 20 ++++++ python/_tskitmodule.c | 128 ++++++++++++++++++++++++++++++++-- python/tests/test_ld_decay.py | 94 ++++++++++++++++++------- python/tskit/trees.py | 58 ++++++++++++++- 4 files changed, 265 insertions(+), 35 deletions(-) diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 6185d69507..eb4e68e30d 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1142,6 +1142,12 @@ typedef int two_locus_decay_stat_method(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, const double *bins, tsk_size_t num_bins, tsk_flags_t options, double *result); +typedef int k_way_two_locus_decay_stat_method(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, const double *bins, tsk_size_t num_bins, + tsk_flags_t options, double *result); + int tsk_treeseq_D_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, tsk_size_t num_bins, tsk_flags_t options, double *result); @@ -1173,6 +1179,20 @@ int tsk_treeseq_pi2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sam const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, tsk_size_t num_bins, tsk_flags_t options, double *result); +int tsk_treeseq_D2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); +int tsk_treeseq_D2_ij_unbiased_decay(const tsk_treeseq_t 
*self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, const double *bins, tsk_size_t num_bins, + tsk_flags_t options, double *result); +int tsk_treeseq_r2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, + tsk_size_t num_bins, tsk_flags_t options, double *result); + /* Two way sample set stats */ int tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index c54ad80df7..861ef05cf6 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -8117,13 +8117,16 @@ TreeSequence_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, { PyObject *ret = NULL; static char *kwlist[] = { "sample_set_sizes", "sample_sets", "bins", "mode", NULL }; - - PyObject *sample_sets = NULL, *sample_set_sizes = NULL; - PyArrayObject *sample_sets_array = NULL, *sample_set_sizes_array = NULL, - *bins = NULL, *result_matrix = NULL; - npy_intp num_bins, result_dim[2]; - char *mode = NULL; + PyObject *sample_sets = NULL; + PyObject *sample_set_sizes = NULL; + PyArrayObject *sample_sets_array = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + PyArrayObject *bins = NULL; + PyArrayObject *result_matrix = NULL; + npy_intp num_bins; + npy_intp result_dim[2]; tsk_size_t num_sample_sets; + char *mode = NULL; tsk_flags_t options = 0; int err; @@ -8232,6 +8235,107 @@ TreeSequence_Dz_unbiased_decay(TreeSequence *self, PyObject *args, PyObject *kwd return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_Dz_unbiased_decay); } +static PyObject * +TreeSequence_k_way_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, + npy_intp tuple_size, k_way_two_locus_decay_stat_method *method) +{ + PyObject *ret = NULL; + static char 
*kwlist[] + = { "sample_set_sizes", "sample_sets", "indexes", "bins", "mode", NULL }; + PyObject *sample_sets = NULL; + PyObject *sample_set_sizes = NULL; + PyObject *indexes = NULL; + PyArrayObject *bins = NULL; + PyArrayObject *indexes_array = NULL; + PyArrayObject *result_matrix = NULL; + PyArrayObject *sample_sets_array = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + npy_intp num_bins; + npy_intp *shape; + npy_intp result_dim[2]; + tsk_size_t num_sample_sets; + tsk_size_t num_set_index_tuples; + char *mode = NULL; + tsk_flags_t options = 0; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO&|s", kwlist, &sample_set_sizes, + &sample_sets, &indexes, &float64_array_converter, &bins, &mode)) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; + } + indexes_array = (PyArrayObject *) PyArray_FROMANY( + indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY); + if (indexes_array == NULL) { + goto out; + } + shape = PyArray_DIMS(indexes_array); + if (shape[0] < 1 || shape[1] != tuple_size) { + PyErr_Format( + PyExc_ValueError, "indexes must be a k x %d array.", (int) tuple_size); + goto out; + } + num_set_index_tuples = shape[0]; + num_bins = PyArray_DIM(bins, 0); + result_dim[0] = num_bins - 1; + result_dim[1] = num_set_index_tuples; + result_matrix = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_FLOAT64, 0); + if (result_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + // clang-format off + Py_BEGIN_ALLOW_THREADS + err = method(self->tree_sequence, num_sample_sets, + PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), num_set_index_tuples, + PyArray_DATA(indexes_array), PyArray_DATA(bins), num_bins, options, PyArray_DATA(result_matrix)); + Py_END_ALLOW_THREADS + // clang-format on + if (err != 
0) + { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_matrix; + result_matrix = NULL; +out: + Py_XDECREF(bins); + Py_XDECREF(indexes_array); + Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(result_matrix); + return ret; +} + +static PyObject * +TreeSequence_D2_ij_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_decay(self, args, kwds, 2, tsk_treeseq_D2_ij_decay); +} + +static PyObject * +TreeSequence_D2_ij_unbiased_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_decay( + self, args, kwds, 2, tsk_treeseq_D2_ij_unbiased_decay); +} + +static PyObject * +TreeSequence_r2_ij_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_decay(self, args, kwds, 2, tsk_treeseq_r2_ij_decay); +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -8991,6 +9095,18 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_r2_ij_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the two-way r^2 matrix." }, + { .ml_name = "r2_ij_decay", + .ml_meth = (PyCFunction) TreeSequence_r2_ij_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way r2 decay curve." }, + { .ml_name = "D2_ij_decay", + .ml_meth = (PyCFunction) TreeSequence_D2_ij_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way D2 decay curve." }, + { .ml_name = "D2_ij_unbiased_decay", + .ml_meth = (PyCFunction) TreeSequence_D2_ij_unbiased_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way unbiased D2 decay curve." 
}, { NULL } /* Sentinel */ }; diff --git a/python/tests/test_ld_decay.py b/python/tests/test_ld_decay.py index 4dd1ba92e1..f52ee1e36a 100644 --- a/python/tests/test_ld_decay.py +++ b/python/tests/test_ld_decay.py @@ -69,9 +69,29 @@ def construct_ld_matrix(ts, stat, sample_sets, indexes): """ bp = ts.breakpoints(as_array=True)[:-1] # TODO: output dims - out = np.zeros((1, ts.num_trees, ts.num_trees)) + # __import__("IPython").embed() + # raise Exception + k = len(sample_sets) if indexes is None else len(indexes) + out = np.zeros((k, ts.num_trees, ts.num_trees)) for i, b in enumerate(bp): - out[0, i, i:] = ts.ld_matrix(mode="branch", stat=stat, positions=[[b], bp[i:]]) + out[0:k, i, i:] = ts.ld_matrix( + sample_sets=sample_sets, + indexes=indexes, + mode="branch", + stat=stat, + positions=[[b], bp[i:]], + )[:, 0, :] # result is for one row + # else: + # __import__("IPython").embed() + # raise Exception + # for k in range(max([k for i in indexes for k in i]) + 1): + # out[k, i, i:] = ts.ld_matrix( + # sample_sets=sample_sets, + # mode="branch", + # stat=stat, + # indexes=indexes, + # positions=[[b], bp[i:]], + # ) return out @@ -104,7 +124,7 @@ def integrate_stat_over_bin(bin, i1, i2, stat): def isect(l1, r1, l2, r2): - """Right closed left open""" + """left open, right closed""" return max(l1, l2) < min(r1, r2) or l1 == r2 or l2 == r1 @@ -127,16 +147,6 @@ def ld_decay_branch(ts, bins, stat, sample_sets, indexes): ivl_r = Interval(bp[j], bp[j + 1]) bounds = get_tree_pair_bounds(ivl_l, ivl_r, bins) for k in range(dims[0]): - # for b in bin_ivls: - # if isect(*bounds, *b): - # s = integrate_stat_over_bin(b, ivl_l, ivl_r, ld[k, i, j]) - # print( - # f"{i}\t{j}\t" - # f"{bounds.left:.15f}\t{bounds.right:.15f}\t" - # f"{b[0]:.15f}\t{b[1]:.15f}\t" - # f"{ld[k, i, j]:.15f}\t" - # f"{s:.15f}" - # ) result[k] += np.apply_along_axis( integrate_stat_over_bin, 1, bin_ivls, ivl_l, ivl_r, ld[k, i, j] ) @@ -147,24 +157,26 @@ def ld_decay_branch(ts, bins, stat, sample_sets, indexes): 
def ld_decay_site(ts, bins, stat, sample_sets, indexes): + # __import__("ipdb").set_trace() ld = ts.ld_matrix(stat=stat, sample_sets=sample_sets, indexes=indexes) dims = (len(indexes or sample_sets), len(bins) - 1) result = np.zeros(dims, dtype=float) bincount = np.zeros(dims, dtype=int) site_pos = ts.sites_position - for i, j in combinations(range(ts.num_sites), 2): # upper tri-diag - dist = site_pos[j] - site_pos[i] - if dist > bins[-1]: - break - bin = np.searchsorted(bins[1:], dist, side="left") - # if bin == 3 and np.isnan(ld[:, i, j]).any(): - # breakpoint() - for k in range(dims[0]): - s = ld[k, i, j] - if np.isnan(s): - continue - result[k, bin] += s - bincount[k, bin] += 1 + for i in range(ts.num_sites): + for j in range(i + 1, ts.num_sites): # upper tri (-diag) + dist = site_pos[j] - site_pos[i] + if dist > bins[-1]: + break + bin = np.searchsorted(bins[1:], dist, side="left") + # if bin == 3 and np.isnan(ld[:, i, j]).any(): + # breakpoint() + for k in range(dims[0]): + s = ld[k, i, j] + if np.isnan(s): + continue + result[k, bin] += s + bincount[k, bin] += 1 if dims[0] == 1: # drop dims if first dim is length 1 return result.reshape(dims[1:]), bincount.reshape(dims[1:]) return result, bincount @@ -190,6 +202,7 @@ def ld_decay( result, count = ld_decay_branch(ts, bins, stat, sample_sets, indexes) case _: raise ValueError(f"Unknown Stats Mode: {mode}") + if return_counts: return result, count with suppress_overflow_div0_warning(): @@ -209,6 +222,8 @@ def ld_decay( "pi2_unbiased", ] +TWO_WAY_STATS = ["r2", "D2", "D2_unbiased"] + TS = msprime.sim_mutations( msprime.sim_ancestry( samples=100, @@ -255,7 +270,32 @@ def test_ld_decay(stat, mode): np.testing.assert_array_almost_equal_nulp( np.nansum(decay), np.nansum(tu), nulp=2 ) - print(f"{stat} diff={np.nansum(decay) - np.nansum(tu)}") elif mode == "site": tu = TS.ld_matrix(stat=stat)[np.triu_indices(TS.num_sites, k=1)] np.testing.assert_allclose(decay.sum(), np.nansum(tu)) + + 
+@pytest.mark.parametrize("stat,mode", product(ONE_WAY_STATS, ["site", "branch"])) +def test_ld_decay_sample_sets(stat, mode): + bins = np.logspace(0, np.log10(TS.sequence_length), num=35) + bins[0] = 0 + sample_sets = [TS.samples(), TS.samples(), TS.samples()] + decay = TS.ld_decay(bins, sample_sets=sample_sets, stat=stat, mode=mode) + np.testing.assert_array_equal(decay[0], decay[1]) + np.testing.assert_array_equal(decay[1], decay[2]) + + +@pytest.mark.slow +@pytest.mark.parametrize("stat,mode", product(TWO_WAY_STATS, ["site", "branch"])) +def test_two_way_ld_decay(stat, mode): + bins = np.logspace(0, np.log10(TS.sequence_length), num=35) + np.testing.assert_array_almost_equal( + ld_decay(TS, bins, stat=stat, mode=mode), + TS.ld_decay(bins, stat=stat, mode=mode), + ) + ss = [TS.samples()] * 3 + indexes = [(0, 0), (0, 1), (1, 1)] + np.testing.assert_array_almost_equal( + ld_decay(TS, bins, stat=stat, mode=mode, sample_sets=ss, indexes=indexes), + TS.ld_decay(bins, stat=stat, mode=mode, sample_sets=ss, indexes=indexes), + ) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 58fbe32092..6527e89e17 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8374,6 +8374,45 @@ def __two_locus_sample_set_decay_stat( return result + def __k_way_two_locus_sample_set_decay_stat( + self, + ll_method, + k, + sample_sets, + bins, + indexes=None, + mode=None, + ): + sample_set_sizes = np.array( + [len(sample_set) for sample_set in sample_sets], dtype=np.uint32 + ) + if np.any(sample_set_sizes == 0): + raise ValueError("Sample sets must contain at least one element") + flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + drop_dimension = False + indexes = util.safe_np_int_cast(indexes, np.int32) + if len(indexes.shape) == 1: + indexes = indexes.reshape((1, indexes.shape[0])) + drop_dimension = True + if len(indexes.shape) != 2 or indexes.shape[1] != k: + raise ValueError( + "Indexes must be convertable to a 2D numpy array with {} " + 
"columns".format(k) + ) + result = ll_method( + sample_set_sizes, + flattened, + indexes, + bins, + mode, + ) + if drop_dimension: + result = result.reshape(result.shape[0]) + else: + # Orient the data so that the first dimension is the sample set. + result = result.swapaxes(0, 1) + return result + def __k_way_weighted_stat( self, ll_method, @@ -10923,8 +10962,8 @@ def ld_matrix( stat_func, sample_sets, sites=sites, positions=positions, mode=mode ) - def ld_decay(self, bins, sample_sets=None, mode="site", stat="r2"): - stats = { + def ld_decay(self, bins, sample_sets=None, mode="site", stat="r2", indexes=None): + one_way_stats = { "D": self._ll_tree_sequence.D_decay, "D2": self._ll_tree_sequence.D2_decay, "r2": self._ll_tree_sequence.r2_decay, @@ -10936,12 +10975,27 @@ def ld_decay(self, bins, sample_sets=None, mode="site", stat="r2"): "D2_unbiased": self._ll_tree_sequence.D2_unbiased_decay, "pi2_unbiased": self._ll_tree_sequence.pi2_unbiased_decay, } + two_way_stats = { + "D2": self._ll_tree_sequence.D2_ij_decay, + "D2_unbiased": self._ll_tree_sequence.D2_ij_unbiased_decay, + "r2": self._ll_tree_sequence.r2_ij_decay, + } + stats = one_way_stats if indexes is None else two_way_stats try: stat_func = stats[stat] except KeyError: raise ValueError( f"Unknown two-locus statistic '{stat}', we support: {list(stats.keys())}" ) + if indexes is not None: + return self.__k_way_two_locus_sample_set_decay_stat( + stat_func, + 2, + sample_sets, + bins, + indexes=indexes, + mode=mode, + ) return self.__two_locus_sample_set_decay_stat( stat_func, sample_sets, From 8befb28d70614145486905cbc4049af68a3e37c2 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 18 Nov 2025 03:46:31 -0600 Subject: [PATCH 6/7] get rid of debugging code and extraneous comment --- c/tskit/trees.c | 34 ---------------------------------- python/tests/test_ld_decay.py | 18 +----------------- 2 files changed, 1 insertion(+), 51 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 
dcd7305e5f..46a2e928d7 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3581,14 +3581,6 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st for (i = 0; i < (num_bins - 1) * result_dim; i++) { result[i] /= bincount[i]; } - // for (k = 0; k < result_dim; k++) { - // bincount_row = GET_2D_ROW(bincount, num_bins - 1, k); - // printf("bincount%lu = { ", k); - // for (i = 0; i < num_bins - 2; i++) { - // printf("%lu, ", bincount_row[i]); - // } - // printf("%lu }\n", bincount_row[i]); - // } out: tsk_safe_free(result_tmp); tsk_safe_free(bincount); @@ -3709,39 +3701,13 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat } result_row[k] += result_tmp[k]; bincount_row[k] += 1; - // if (result_row[0] != result_row[k]) { - // printf("BAD %lu\t%f\t%f\n", k, result_row[0], result_row[k]); - // } } tsk_memset(result_tmp, 0, sizeof(*result_tmp) * result_dim); } } - // puts("result\t\tbincount"); for (i = 0; i < (num_bins - 1) * result_dim; i++) { - // printf("%f\t%lu\n", result[i], bincount[i]); result[i] /= bincount[i]; } - // puts(""); - // puts("======="); - // puts("bincount = {"); - // for (i = 0; i < num_bins - 1; i++) { - // bincount_row = GET_2D_ROW(bincount, result_dim, i); - // printf(" { "); - // for (k = 0; k < result_dim - 1; k++) { - // printf("%lu, ", bincount_row[k]); - // } - // printf("%lu },\n", bincount_row[k]); - // } - // puts("result = {"); - // for (i = 0; i < num_bins - 1; i++) { - // result_row = GET_2D_ROW(result, result_dim, i); - // printf(" { "); - // for (k = 0; k < result_dim - 1; k++) { - // printf("%f, ", result_row[k]); - // } - // printf("%f },\n", result_row[k]); - // } - // puts("}"); out: tsk_safe_free(sites); tsk_safe_free(bincount); diff --git a/python/tests/test_ld_decay.py b/python/tests/test_ld_decay.py index f52ee1e36a..d568b9be0e 100644 --- a/python/tests/test_ld_decay.py +++ b/python/tests/test_ld_decay.py @@ -68,9 +68,6 @@ def construct_ld_matrix(ts, stat, 
sample_sets, indexes): different (likely more) amount of error. """ bp = ts.breakpoints(as_array=True)[:-1] - # TODO: output dims - # __import__("IPython").embed() - # raise Exception k = len(sample_sets) if indexes is None else len(indexes) out = np.zeros((k, ts.num_trees, ts.num_trees)) for i, b in enumerate(bp): @@ -81,17 +78,6 @@ def construct_ld_matrix(ts, stat, sample_sets, indexes): stat=stat, positions=[[b], bp[i:]], )[:, 0, :] # result is for one row - # else: - # __import__("IPython").embed() - # raise Exception - # for k in range(max([k for i in indexes for k in i]) + 1): - # out[k, i, i:] = ts.ld_matrix( - # sample_sets=sample_sets, - # mode="branch", - # stat=stat, - # indexes=indexes, - # positions=[[b], bp[i:]], - # ) return out @@ -111,7 +97,7 @@ def integrate_stat_over_bin(bin, i1, i2, stat): r2_l = min(max(bl, r2_l_bound), r2_r_bound) r2_r = max(min(br, r2_r_bound), r2_l_bound) r3_l = min(max(bl, r2_r_bound), r_support) - r3_r = max(min(br, r_support), r2_r_bound) # this one differs from mm nb + r3_r = max(min(br, r_support), r2_r_bound) return ( stat / (i1.span * i2.span) @@ -169,8 +155,6 @@ def ld_decay_site(ts, bins, stat, sample_sets, indexes): if dist > bins[-1]: break bin = np.searchsorted(bins[1:], dist, side="left") - # if bin == 3 and np.isnan(ld[:, i, j]).any(): - # breakpoint() for k in range(dims[0]): s = ld[k, i, j] if np.isnan(s): From 8d0f78e3762a890dcb4d5df5ace935ed4f56dfe4 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 28 Nov 2025 18:29:25 -0600 Subject: [PATCH 7/7] bin boundaries; return bincount; optional ratemap This change fixes the bin boundaries. Now, all bin boundaries for sites will be left closed right open. The boundaries for branches will be closed on both sides, minimizing the number of spanning bins. This is to avoid integrating over a region of 0 length, which would affect the bincount over a region that integrates to 0. 
Since we're performing an integration, it doesn't matter if the left/right boundaries are opened or closed. Finally, add support for specifying a recombination map. I've performed some minimal testing, but I'm not quite sure if it's working correctly. I have a suspicion that there's a bug somewhere. It will allow the user to specify bin boundaries in cM. Requires the tree to be simulated with the same recombination map. --- c/tskit/trees.c | 124 ++++++++++++++++-------------- c/tskit/trees.h | 42 ++++++---- python/_tskitmodule.c | 141 +++++++++++++++++++++++++++------- python/tests/test_ld_decay.py | 19 ++--- python/tskit/trees.py | 60 +++++++++++++-- 5 files changed, 268 insertions(+), 118 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 46a2e928d7..e1b4ff6efb 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3495,23 +3495,21 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, const double *bins, tsk_size_t num_bins, - double *result) + const double *restrict breakpoints, double *result, tsk_size_t *bincount) { int ret = 0; - const double *restrict breakpoints = self->breakpoints; const tsk_size_t num_nodes = self->tables->nodes.num_rows; interval_t bounds, ivl_l, ivl_r; iter_state l_state, r_state; double *result_tmp = NULL, *result_row; tsk_bitset_t node_samples, sample_sets_bits; - tsk_size_t i, j, k, bin_l, bin_r, *bincount = NULL, *bincount_row; + tsk_size_t i, j, k, bin_l, bin_r, *bincount_row; tsk_memset(&sample_sets_bits, 0, sizeof(sample_sets_bits)); tsk_memset(&node_samples, 0, sizeof(node_samples)); tsk_memset(&l_state, 0, sizeof(l_state)); tsk_memset(&r_state, 0, sizeof(r_state)); result_tmp = tsk_malloc(result_dim * sizeof(*result_tmp)); - bincount = tsk_calloc(result_dim * (num_bins - 1), sizeof(*bincount)); if (result_tmp 
== NULL) { ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; @@ -3549,13 +3547,16 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st } for (j = i; j < self->num_trees; j++) { ivl_r = (interval_t){ breakpoints[j], breakpoints[j + 1] }; - bounds = (interval_t){ fmax(0, ivl_r.left - ivl_l.right), + bounds = (interval_t){ fmax(bins[0], ivl_r.left - ivl_l.right), fmin(bins[num_bins - 1], ivl_r.right - ivl_l.left) }; - bin_l = tsk_search_sorted(bins + 1, num_bins - 1, bounds.left); - bin_r = tsk_search_sorted(bins + 1, num_bins - 1, bounds.right); if (bounds.left > bins[num_bins - 1] || bounds.right < bins[0]) { continue; } + bin_l = tsk_search_sorted(bins + 1, num_bins - 1, bounds.left); + bin_r = tsk_search_sorted(bins + 1, num_bins - 1, bounds.right); + if (bin_l + 1 <= bin_r && bins[bin_l + 1] == bounds.left) { + bin_l += 1; + } ret = advance_collect_edges(&r_state, (tsk_id_t) j); if (ret != 0) { goto out; @@ -3569,21 +3570,20 @@ tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t st result_row = GET_2D_ROW(result, result_dim, bin_l); bincount_row = GET_2D_ROW(bincount, result_dim, bin_l); for (k = 0; k < result_dim; k++) { - // TODO: nansum?? 
- result_row[k] += integrate_stat_over_bin( + double val = integrate_stat_over_bin( ivl_l, ivl_r, bins[bin_l], bins[bin_l + 1], result_tmp[k]); + if (tsk_isnan(val)) { + continue; + } + result_row[k] += val; bincount_row[k] += 1; } bin_l++; } while (bin_l <= bin_r); } } - for (i = 0; i < (num_bins - 1) * result_dim; i++) { - result[i] /= bincount[i]; - } out: tsk_safe_free(result_tmp); - tsk_safe_free(bincount); iter_state_free(&l_state); iter_state_free(&r_state); tsk_bitset_free(&node_samples); @@ -3596,7 +3596,8 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, norm_func_t *norm_f, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *restrict sites_position, tsk_flags_t options, + double *result, tsk_size_t *bincount) { int ret = 0; tsk_bitset_t allele_samples, allele_sample_sets; @@ -3605,9 +3606,7 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat tsk_size_t i, j, k, bin, n_sites, *bincount_row; double dist, *result_row, *result_tmp = NULL; const tsk_size_t num_samples = self->num_samples; - const double *restrict site_position = self->tables->sites.position; - tsk_size_t *bincount = NULL, *num_alleles = NULL, *site_offsets = NULL, - *allele_counts = NULL; + tsk_size_t *num_alleles = NULL, *site_offsets = NULL, *allele_counts = NULL; tsk_size_t max_ss_size = 0, max_alleles = 0, n_alleles = 0; two_locus_work_t work; @@ -3625,7 +3624,6 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles)); site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets)); result_tmp = tsk_malloc(result_dim * sizeof(*result_tmp)); - bincount = tsk_calloc(result_dim * (num_bins - 1), 
sizeof(*bincount)); if (num_alleles == NULL || site_offsets == NULL || result_tmp == NULL || bincount == NULL) { ret = tsk_trace_error(TSK_ERR_NO_MEMORY); @@ -3669,15 +3667,15 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat sample_sets, self->sample_index_map, &allele_sample_sets, allele_counts); for (i = 0; i < n_sites; i++) { for (j = i + 1; j < n_sites; j++) { - dist = site_position[j] - site_position[i]; - if (dist > bins[num_bins - 1]) { + dist = sites_position[j] - sites_position[i]; + if (dist >= bins[num_bins - 1]) { // right open break; } - // TODO: Very left interval is closed? - if (dist < bins[0]) { + if (dist < bins[0]) { // left closed continue; } - bin = tsk_search_sorted(bins + 1, num_bins - 1, dist); + bin = tsk_search_sorted(bins, num_bins, dist); + bin = bins[bin] > dist ? bin - 1 : bin; // left closed intervals result_row = GET_2D_ROW(result, result_dim, bin); bincount_row = GET_2D_ROW(bincount, result_dim, bin); if (num_alleles[i] == 2 && num_alleles[j] == 2) { @@ -3705,12 +3703,8 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat tsk_memset(result_tmp, 0, sizeof(*result_tmp) * result_dim); } } - for (i = 0; i < (num_bins - 1) * result_dim; i++) { - result[i] /= bincount[i]; - } out: tsk_safe_free(sites); - tsk_safe_free(bincount); tsk_safe_free(result_tmp); tsk_safe_free(num_alleles); tsk_safe_free(site_offsets); @@ -3721,12 +3715,15 @@ tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t stat return ret; } +// In two_locus_decay_stat, we specify positions. These can be site positions or tree +// breakpoints. We pass them in at this level so that we can convert their overall +// positions using a recombination map if we'd like. 
static int tsk_treeseq_two_locus_decay_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, - norm_func_t *norm_f, const double *bins, tsk_size_t num_bins, tsk_flags_t options, - double *result) + norm_func_t *norm_f, const double *bins, tsk_size_t num_bins, + const double *positions, tsk_flags_t options, double *result, tsk_size_t *bincount) { int ret = 0; bool stat_site = !!(options & TSK_STAT_SITE); @@ -3763,11 +3760,11 @@ tsk_treeseq_two_locus_decay_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl if (stat_site) { ret = tsk_treeseq_two_locus_site_decay_stat(self, state_dim, num_sample_sets, sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, bins, - num_bins, options, result); + num_bins, positions, options, result, bincount); } else if (stat_branch) { ret = tsk_treeseq_two_locus_branch_decay_stat(self, state_dim, num_sample_sets, sample_set_sizes, sample_sets, result_dim, f, &f_params, bins, num_bins, - result); + positions, result, bincount); goto out; } else { ret = TSK_ERR_UNSUPPORTED_STAT_MODE; @@ -4629,12 +4626,13 @@ tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_D_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { options |= TSK_STAT_POLARISED; // TODO: allow user to pick? 
return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_summary_func, norm_total_weighted, bins, - num_bins, options, result); + num_bins, positions, options, result, bincount); } static int @@ -4678,11 +4676,12 @@ tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_D2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D2_summary_func, norm_total_weighted, bins, - num_bins, options, result); + num_bins, positions, options, result, bincount); } static int @@ -4727,11 +4726,12 @@ tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_r2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, r2_summary_func, norm_hap_weighted, bins, - num_bins, options, result); + num_bins, positions, options, result, bincount); } static int @@ -4781,12 +4781,13 @@ tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_D_prime_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, 
tsk_flags_t options, double *result, + tsk_size_t *bincount) { options |= TSK_STAT_POLARISED; return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_prime_summary_func, norm_total_weighted, - bins, num_bins, options, result); + bins, num_bins, positions, options, result, bincount); } static int @@ -4833,12 +4834,13 @@ tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_r_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { options |= TSK_STAT_POLARISED; return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, r_summary_func, norm_total_weighted, bins, - num_bins, options, result); + num_bins, positions, options, result, bincount); } static int @@ -4883,11 +4885,12 @@ tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_Dz_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, Dz_summary_func, norm_total_weighted, bins, - num_bins, options, result); + num_bins, positions, options, result, bincount); } static int @@ -4929,11 +4932,12 @@ tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_pi2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, 
const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, pi2_summary_func, norm_total_weighted, bins, - num_bins, options, result); + num_bins, positions, options, result, bincount); } static int @@ -4976,11 +4980,12 @@ tsk_treeseq_D2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_D2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D2_unbiased_summary_func, - norm_total_weighted, bins, num_bins, options, result); + norm_total_weighted, bins, num_bins, positions, options, result, bincount); } static int @@ -5024,11 +5029,12 @@ tsk_treeseq_Dz_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, int tsk_treeseq_Dz_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, Dz_unbiased_summary_func, - norm_total_weighted, bins, num_bins, options, result); + norm_total_weighted, bins, num_bins, positions, options, result, bincount); } static int @@ -5072,11 +5078,12 @@ tsk_treeseq_pi2_unbiased(const tsk_treeseq_t 
*self, tsk_size_t num_sample_sets, int tsk_treeseq_pi2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, pi2_unbiased_summary_func, - norm_total_weighted, bins, num_bins, options, result); + norm_total_weighted, bins, num_bins, positions, options, result, bincount); } /*********************************** @@ -5459,7 +5466,8 @@ int tsk_treeseq_D2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { int ret = 0; ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); @@ -5468,7 +5476,7 @@ tsk_treeseq_D2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, } ret = tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_index_tuples, index_tuples, D2_ij_summary_func, - norm_total_weighted, bins, num_bins, options, result); + norm_total_weighted, bins, num_bins, positions, options, result, bincount); out: return ret; } @@ -5551,7 +5559,8 @@ int tsk_treeseq_D2_ij_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t 
options, double *result, + tsk_size_t *bincount) { int ret = 0; ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); @@ -5560,7 +5569,7 @@ tsk_treeseq_D2_ij_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sampl } ret = tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_index_tuples, index_tuples, D2_ij_unbiased_summary_func, - norm_total_weighted, bins, num_bins, options, result); + norm_total_weighted, bins, num_bins, positions, options, result, bincount); out: return ret; } @@ -5629,7 +5638,8 @@ int tsk_treeseq_r2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result) + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) { int ret = 0; ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); @@ -5638,7 +5648,7 @@ tsk_treeseq_r2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, } ret = tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_index_tuples, index_tuples, r2_ij_summary_func, - norm_hap_weighted_ij, bins, num_bins, options, result); + norm_hap_weighted_ij, bins, num_bins, positions, options, result, bincount); out: return ret; } diff --git a/c/tskit/trees.h b/c/tskit/trees.h index eb4e68e30d..7bf6ff3fc2 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1140,58 +1140,70 @@ typedef int k_way_two_locus_count_stat_method(const tsk_treeseq_t *self, typedef int two_locus_decay_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, tsk_size_t num_bins, - tsk_flags_t options, double *result); + const double *positions, tsk_flags_t options, 
double *result, tsk_size_t *bincount); typedef int k_way_two_locus_decay_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, tsk_size_t num_bins, - tsk_flags_t options, double *result); + const double *positions, tsk_flags_t options, double *result, tsk_size_t *bincount); int tsk_treeseq_D_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_D2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_r2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_D_prime_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_r_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double 
*result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_Dz_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_pi2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_D2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_Dz_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_pi2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_D2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t 
num_index_tuples, const tsk_id_t *index_tuples, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); int tsk_treeseq_D2_ij_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, tsk_size_t num_bins, - tsk_flags_t options, double *result); + const double *positions, tsk_flags_t options, double *result, tsk_size_t *bincount); int tsk_treeseq_r2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, - tsk_size_t num_bins, tsk_flags_t options, double *result); + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); /* Two way sample set stats */ diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 861ef05cf6..b436cc00dc 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -996,6 +996,16 @@ float64_array_converter(PyObject *py_obj, PyArrayObject **array_out) return array_converter(NPY_FLOAT64, py_obj, array_out); } +static int +optional_float64_array_converter(PyObject *py_obj, PyArrayObject **array_out) +{ + if (py_obj != Py_None) { + return array_converter(NPY_FLOAT64, py_obj, array_out); + } + *array_out = (PyArrayObject *) Py_None; + return 1; +} + /* Note: it doesn't seem to be possible to cast pointers to the actual * table functions to this type because the first argument must be a * void *, so the simplest option is to put in a small shim that @@ -8111,21 +8121,64 @@ TreeSequence_r2_ij_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) return TreeSequence_k_way_ld_matrix(self, args, kwds, 2, tsk_treeseq_r2_ij); } +static int 
+parse_decay_positions(const tsk_treeseq_t *ts, PyArrayObject *positions_array, + tsk_flags_t options, const double **out) +{ + bool stat_site, stat_branch; + tsk_size_t positions_len; + stat_site = !!(options & TSK_STAT_SITE); + stat_branch = !!(options & TSK_STAT_BRANCH); + + if (!(stat_site || stat_branch)) { + return 0; // mode validation happens later, out as NULL will be fine + } + + if ((PyObject *) positions_array == Py_None) { + *out = stat_site ? ts->tables->sites.position : ts->breakpoints; + } else { + if (PyArray_NDIM(positions_array) != 1) { + PyErr_Format(PyExc_ValueError, "positions must be a 1d array."); + return 1; + } + positions_len = PyArray_DIM(positions_array, 0); + if (stat_site && (ts->tables->sites.num_rows != positions_len)) { + PyErr_Format(PyExc_ValueError, + "site positions must contain one element per site " + "(want a length %lu array).", + ts->tables->sites.num_rows); + return 1; + } else if (stat_branch && (ts->num_trees + 1 != positions_len)) { + PyErr_Format(PyExc_ValueError, + "site positions must contain one element per tree breakpoint" + "(want a length %lu array).", + ts->num_trees + 1); + return 1; + } + *out = PyArray_DATA(positions_array); + } + return 0; +} + static PyObject * TreeSequence_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, two_locus_decay_stat_method *method) { PyObject *ret = NULL; - static char *kwlist[] = { "sample_set_sizes", "sample_sets", "bins", "mode", NULL }; + static char *kwlist[] + = { "sample_set_sizes", "sample_sets", "bins", "positions", "mode", NULL }; PyObject *sample_sets = NULL; PyObject *sample_set_sizes = NULL; PyArrayObject *sample_sets_array = NULL; PyArrayObject *sample_set_sizes_array = NULL; PyArrayObject *bins = NULL; - PyArrayObject *result_matrix = NULL; + PyArrayObject *positions_array = NULL; + PyArrayObject *result_stat_matrix = NULL; + PyArrayObject *result_bincount_matrix = NULL; npy_intp num_bins; npy_intp result_dim[2]; tsk_size_t num_sample_sets; + const 
double *positions = NULL; char *mode = NULL; tsk_flags_t options = 0; int err; @@ -8133,8 +8186,9 @@ TreeSequence_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO&|s", kwlist, &sample_set_sizes, - &sample_sets, &float64_array_converter, &bins, &mode)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO&|O&s", kwlist, &sample_set_sizes, + &sample_sets, &float64_array_converter, &bins, + &optional_float64_array_converter, &positions_array, &mode)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -8145,11 +8199,21 @@ TreeSequence_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, != 0) { goto out; } + if (parse_decay_positions(self->tree_sequence, positions_array, options, &positions) + != 0) { + goto out; + } num_bins = PyArray_DIM(bins, 0); result_dim[0] = num_bins - 1; result_dim[1] = num_sample_sets; - result_matrix = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_FLOAT64, 0); - if (result_matrix == NULL) { + result_stat_matrix = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_FLOAT64, 0); + if (result_stat_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + result_bincount_matrix + = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_UINT64, 0); + if (result_bincount_matrix == NULL) { PyErr_NoMemory(); goto out; } @@ -8157,7 +8221,7 @@ TreeSequence_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, Py_BEGIN_ALLOW_THREADS err = method(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), - PyArray_DATA(bins), num_bins, options, PyArray_DATA(result_matrix)); + PyArray_DATA(bins), num_bins, positions, options, PyArray_DATA(result_stat_matrix), PyArray_DATA(result_bincount_matrix)); Py_END_ALLOW_THREADS // clang-format on if (err != 0) @@ -8165,13 +8229,19 @@ TreeSequence_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, 
handle_library_error(err); goto out; } - ret = (PyObject *) result_matrix; - result_matrix = NULL; + // Return a tuple of matrixes: (stat, count) + ret = PyTuple_New(2); + PyTuple_SET_ITEM(ret, 0, (PyObject *) result_stat_matrix); + PyTuple_SET_ITEM(ret, 1, (PyObject *) result_bincount_matrix); + result_stat_matrix = NULL; + result_bincount_matrix = NULL; out: Py_XDECREF(bins); + Py_XDECREF(positions_array); Py_XDECREF(sample_sets_array); Py_XDECREF(sample_set_sizes_array); - Py_XDECREF(result_matrix); + Py_XDECREF(result_stat_matrix); + Py_XDECREF(result_bincount_matrix); return ret; } @@ -8240,14 +8310,16 @@ TreeSequence_k_way_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, npy_intp tuple_size, k_way_two_locus_decay_stat_method *method) { PyObject *ret = NULL; - static char *kwlist[] - = { "sample_set_sizes", "sample_sets", "indexes", "bins", "mode", NULL }; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "indexes", "bins", + "positions", "mode", NULL }; PyObject *sample_sets = NULL; PyObject *sample_set_sizes = NULL; PyObject *indexes = NULL; PyArrayObject *bins = NULL; + PyArrayObject *positions_array = NULL; PyArrayObject *indexes_array = NULL; - PyArrayObject *result_matrix = NULL; + PyArrayObject *result_stat_matrix = NULL; + PyArrayObject *result_bincount_matrix = NULL; PyArrayObject *sample_sets_array = NULL; PyArrayObject *sample_set_sizes_array = NULL; npy_intp num_bins; @@ -8255,6 +8327,7 @@ TreeSequence_k_way_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, npy_intp result_dim[2]; tsk_size_t num_sample_sets; tsk_size_t num_set_index_tuples; + const double *positions = NULL; char *mode = NULL; tsk_flags_t options = 0; int err; @@ -8262,8 +8335,9 @@ TreeSequence_k_way_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO&|s", kwlist, &sample_set_sizes, - &sample_sets, &indexes, 
&float64_array_converter, &bins, &mode)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO&|O&s", kwlist, &sample_set_sizes, + &sample_sets, &indexes, &float64_array_converter, &bins, + &optional_float64_array_converter, &positions_array, &mode)) { goto out; } if (parse_stats_mode(mode, &options) != 0) { @@ -8274,6 +8348,10 @@ TreeSequence_k_way_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, != 0) { goto out; } + if (parse_decay_positions(self->tree_sequence, positions_array, options, &positions) + != 0) { + goto out; + } indexes_array = (PyArrayObject *) PyArray_FROMANY( indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY); if (indexes_array == NULL) { @@ -8289,8 +8367,14 @@ TreeSequence_k_way_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, num_bins = PyArray_DIM(bins, 0); result_dim[0] = num_bins - 1; result_dim[1] = num_set_index_tuples; - result_matrix = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_FLOAT64, 0); - if (result_matrix == NULL) { + result_stat_matrix = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_FLOAT64, 0); + if (result_stat_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + result_bincount_matrix + = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_UINT64, 0); + if (result_bincount_matrix == NULL) { PyErr_NoMemory(); goto out; } @@ -8298,7 +8382,7 @@ TreeSequence_k_way_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, Py_BEGIN_ALLOW_THREADS err = method(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), num_set_index_tuples, - PyArray_DATA(indexes_array), PyArray_DATA(bins), num_bins, options, PyArray_DATA(result_matrix)); + PyArray_DATA(indexes_array), PyArray_DATA(bins), num_bins, positions, options, PyArray_DATA(result_stat_matrix), PyArray_DATA(result_bincount_matrix)); Py_END_ALLOW_THREADS // clang-format on if (err != 0) @@ -8306,14 +8390,20 @@ TreeSequence_k_way_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, 
handle_library_error(err); goto out; } - ret = (PyObject *) result_matrix; - result_matrix = NULL; + // Return a tuple of matrixes: (stat, count) + ret = PyTuple_New(2); + PyTuple_SET_ITEM(ret, 0, (PyObject *) result_stat_matrix); + PyTuple_SET_ITEM(ret, 1, (PyObject *) result_bincount_matrix); + result_stat_matrix = NULL; + result_bincount_matrix = NULL; out: Py_XDECREF(bins); Py_XDECREF(indexes_array); + Py_XDECREF(positions_array); Py_XDECREF(sample_sets_array); Py_XDECREF(sample_set_sizes_array); - Py_XDECREF(result_matrix); + Py_XDECREF(result_stat_matrix); + Py_XDECREF(result_bincount_matrix); return ret; } @@ -10818,8 +10908,8 @@ static PyMethodDef Tree_methods[] = { { .ml_name = "map_mutations", .ml_meth = (PyCFunction) Tree_map_mutations, .ml_flags = METH_VARARGS | METH_KEYWORDS, - .ml_doc - = "Returns a parsimonious state reconstruction for the specified genotypes." }, + .ml_doc = "Returns a parsimonious state reconstruction for the specified " + "genotypes." }, { .ml_name = "equals", .ml_meth = (PyCFunction) Tree_equals, .ml_flags = METH_VARARGS, @@ -12249,9 +12339,8 @@ PyInit__tskit(void) return NULL; } Py_INCREF(&LsHmmType); - PyModule_AddObject(module, "LsHmm", (PyObject *) &LsHmmType); - - /* IdentitySegments type */ + PyModule_AddObject( + module, "LsHmm", (PyObject *) &LsHmmType); /* IdentitySegments type */ if (PyType_Ready(&IdentitySegmentsType) < 0) { return NULL; } diff --git a/python/tests/test_ld_decay.py b/python/tests/test_ld_decay.py index d568b9be0e..3f0d5747f2 100644 --- a/python/tests/test_ld_decay.py +++ b/python/tests/test_ld_decay.py @@ -1,5 +1,4 @@ import contextlib -from itertools import combinations from itertools import combinations_with_replacement from itertools import product @@ -25,12 +24,9 @@ def expand_dims(arr): necessary in the C implementation because dimension dropping happens in the python layer. 
""" - try: - arr = np.asarray(arr) - if arr.ndim == 1: - return np.expand_dims(arr, axis=0) - except: - pass + arr = np.asarray(arr) + if arr.ndim == 1: + return np.expand_dims(arr, axis=0) try: arr = [np.asarray(a) for a in arr] except Exception as e: @@ -110,13 +106,13 @@ def integrate_stat_over_bin(bin, i1, i2, stat): def isect(l1, r1, l2, r2): - """left open, right closed""" - return max(l1, l2) < min(r1, r2) or l1 == r2 or l2 == r1 + "left closed, right open, left is ivl and right is query" + return max(l1, l2) < min(r1, r2) def get_tree_pair_bounds(ivl_l, ivl_r, bins): return Interval( - max(0, ivl_r.left - ivl_l.right), + max(bins[0], ivl_r.left - ivl_l.right), min(bins[-1], ivl_r.right - ivl_l.left), ) @@ -143,7 +139,6 @@ def ld_decay_branch(ts, bins, stat, sample_sets, indexes): def ld_decay_site(ts, bins, stat, sample_sets, indexes): - # __import__("ipdb").set_trace() ld = ts.ld_matrix(stat=stat, sample_sets=sample_sets, indexes=indexes) dims = (len(indexes or sample_sets), len(bins) - 1) result = np.zeros(dims, dtype=float) @@ -250,7 +245,7 @@ def test_ld_decay(stat, mode): dmask = np.diag_indices_from(tu) tu[dmask] = tu[dmask] / 2 # we take half the density on the diagonal np.testing.assert_allclose(np.nansum(decay), np.nansum(tu)) - # all but r2 D2 Dz are within 1 ulp + # all but r2 D2 Dz are within 1 ulp, likely due to numerical precision np.testing.assert_array_almost_equal_nulp( np.nansum(decay), np.nansum(tu), nulp=2 ) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 6527e89e17..09aa89cf2d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8352,6 +8352,8 @@ def __two_locus_sample_set_decay_stat( ll_method, sample_sets, bins, + return_counts, + ratemap=None, mode=None, ): if sample_sets is None: @@ -8363,16 +8365,28 @@ def __two_locus_sample_set_decay_stat( ) if np.any(sample_set_sizes == 0): raise ValueError("Sample sets must contain at least one element") - flattened = util.safe_np_int_cast(np.hstack(sample_sets), 
np.int32) - result = ll_method(sample_set_sizes, flattened, bins, mode) + positions = None + if ratemap is not None: + # rate in cM + if mode is None or mode == "site": + positions = ratemap.get_cumulative_mass(self.sites_position) * 100 + elif mode == "branch": + positions = ( + ratemap.get_cumulative_mass(self.breakpoints(as_array=True)) * 100 + ) + result, counts = ll_method(sample_set_sizes, flattened, bins, positions, mode) if drop_dimension: result = result.reshape(result.shape[0]) + counts = counts.reshape(counts.shape[0]) else: # Orient the data so that the first dimension is the sample set. result = result.swapaxes(0, 1) - - return result + counts = counts.swapaxes(0, 1) + if return_counts: + return result, counts + with np.errstate(divide="ignore", invalid="ignore"): + return result / counts def __k_way_two_locus_sample_set_decay_stat( self, @@ -8380,7 +8394,9 @@ def __k_way_two_locus_sample_set_decay_stat( k, sample_sets, bins, + return_counts, indexes=None, + ratemap=None, mode=None, ): sample_set_sizes = np.array( @@ -8399,19 +8415,34 @@ def __k_way_two_locus_sample_set_decay_stat( "Indexes must be convertable to a 2D numpy array with {} " "columns".format(k) ) - result = ll_method( + positions = None + if ratemap is not None: + # rate in cM + if mode is None or mode == "site": + positions = ratemap.get_cumulative_mass(self.sites_position) * 100 + elif mode == "branch": + positions = ( + ratemap.get_cumulative_mass(self.breakpoints(as_array=True)) * 100 + ) + result, counts = ll_method( sample_set_sizes, flattened, indexes, bins, + positions, mode, ) if drop_dimension: result = result.reshape(result.shape[0]) + counts = counts.reshape(counts.shape[0]) else: # Orient the data so that the first dimension is the sample set. 
result = result.swapaxes(0, 1) - return result + counts = counts.swapaxes(0, 1) + if return_counts: + return result, counts + with np.errstate(divide="ignore", invalid="ignore"): + return result / counts def __k_way_weighted_stat( self, @@ -10962,7 +10993,16 @@ def ld_matrix( stat_func, sample_sets, sites=sites, positions=positions, mode=mode ) - def ld_decay(self, bins, sample_sets=None, mode="site", stat="r2", indexes=None): + def ld_decay( + self, + bins, + sample_sets=None, + mode="site", + stat="r2", + indexes=None, + ratemap=None, + return_counts=False, + ): one_way_stats = { "D": self._ll_tree_sequence.D_decay, "D2": self._ll_tree_sequence.D2_decay, @@ -10993,13 +11033,17 @@ def ld_decay(self, bins, sample_sets=None, mode="site", stat="r2", indexes=None) 2, sample_sets, bins, + return_counts, indexes=indexes, + ratemap=ratemap, mode=mode, ) return self.__two_locus_sample_set_decay_stat( stat_func, sample_sets, - bins=bins, + bins, + return_counts, + ratemap=ratemap, mode=mode, )