From 00e2db0e8d4580f22e03d79c8081e324855eb429 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Fri, 5 Dec 2025 09:53:58 -0500 Subject: [PATCH 01/41] First commit on superyacht branch, including files required to calculate effective coverage, etc according to sylph (Shaw and Yu, 2024). --- cov_calc.py | 280 ++++++++++++ hypothesis_recovery_src.py | 480 ++++++++++++++++++++ internal_superyacht_test.py | 71 +++ utils.py | 860 ++++++++++++++++++++++++++++++++++++ 4 files changed, 1691 insertions(+) create mode 100644 cov_calc.py create mode 100644 hypothesis_recovery_src.py create mode 100755 internal_superyacht_test.py create mode 100755 utils.py diff --git a/cov_calc.py b/cov_calc.py new file mode 100644 index 00000000..2d5ea5f5 --- /dev/null +++ b/cov_calc.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python +import sourmash +import math +import numpy as np +import pandas as pd +from loguru import logger +from yacht.utils import ratio_lambda +from yacht.utils import mme_lambda +from yacht.utils import binary_search_lambda +from yacht.utils import mle_zip +from yacht.utils import bootstrap_interval +from yacht.utils import ani_from_lambda +from yacht.utils import _ContainArgs +from yacht.utils import AniResult +from yacht.utils import AdjustStatusLambda, AdjustStatusLow, AdjustStatusHigh, AdjustStatusNone +from scipy.stats import poisson, variation +from typing import Optional, Tuple, Dict, Any + +SAMPLE_SIZE_CUTOFF: int = 25 #using the sylph (Shaw and Yu, 2024) defaults here +PVALUE_CUTOFF: float = 0.9999999999 +MEDIAN_ANI_THRESHOLD: float = 2.00 +MAX_MEDIAN_FOR_MEAN_FINAL_EST: float = 15.0 +MIN_COUNT_THRESH=3 +ksize=31 + +no_adj = False #consider updating this in future SUPERYACHT arguments +winner_map = None #skipping this step in this version +kmers_lost_count = None + +# Creates instances of the simple states +ADJUST_STATUS_NONE = AdjustStatusNone() +ADJUST_STATUS_HIGH = AdjustStatusHigh() +ADJUST_STATUS_LOW = AdjustStatusLow() + +# Define a Union type hint for 
# Union of every lambda-adjustment outcome (Python counterpart of sylph's AdjustStatus enum).
AdjustStatus = AdjustStatusLambda | AdjustStatusHigh | AdjustStatusLow | AdjustStatusNone


def cov_calc(
    sample_sig: sourmash.SourmashSignature,
    genome_sig: sourmash.SourmashSignature,
) -> Optional[pd.DataFrame]:
    """
    Estimate effective coverage (lambda) and ANI of one genome sketch in a sample sketch,
    following the coverage model of sylph (Shaw and Yu, 2024).

    :param sample_sig: sourmash signature of the sample; its k-mer abundances are used as coverages
    :param genome_sig: sourmash signature of the reference genome
    :return: None when no genome k-mer occurs in the sample with a non-zero abundance; otherwise a
        one-row dataframe built from an AniResult (naive/adjusted ANI, estimated coverage,
        mean/median coverage, containment index, lambda status and confidence intervals)
    """
    args = _ContainArgs()

    genome_hashes = genome_sig.minhash.hashes
    genome_items = genome_hashes.items()
    sample_abund = dict(sample_sig.minhash.hashes.items())

    # Sorted abundances of the genome k-mers that are present in the sample (zero counts skipped).
    covs = sorted(
        sample_abund[kmer]
        for kmer in genome_hashes
        if kmer in sample_abund and sample_abund[kmer] != 0
    )
    contain_count = len(covs)
    if not covs:
        return None

    naive_ani = math.pow(contain_count / len(genome_items), 1 / ksize)

    half = len(covs) // 2
    median_cov = covs[half]

    # For low-coverage genomes, cap the coverages at the largest value (scanning upwards from
    # the median) that is still plausible under a Poisson(median) model.
    cov_max = float("inf")
    if median_cov < 30:
        pois = poisson(median_cov)
        for cov in covs[half:]:
            if pois.cdf(cov) < PVALUE_CUTOFF:
                cov_max = cov
            else:
                break

    # One zero per genome k-mer missing from the sample, followed by the retained coverages.
    full_covs = [0] * (len(genome_hashes) - contain_count)
    full_covs.extend(cov for cov in covs if cov <= cov_max)

    logger.debug("VAR {} {}", variation(full_covs), genome_sig.name)

    # True division: these are float-valued means (AniResult.mean_cov is a float); floor
    # division truncated them to integers.
    total_cov = sum(full_covs)
    mean_cov = total_cov / len(full_covs)
    geq1_mean_cov = total_cov / len(covs)

    return_lambda: AdjustStatus
    if median_cov > MEDIAN_ANI_THRESHOLD:
        # High coverage: no lambda estimation; use the mean over observed k-mers for moderate
        # medians and the median itself otherwise.
        return_lambda = ADJUST_STATUS_HIGH
        if median_cov < MAX_MEDIAN_FOR_MEAN_FINAL_EST:
            opt_lambda = geq1_mean_cov
        else:
            opt_lambda = median_cov
    else:
        if args.ratio:
            test_lambda = ratio_lambda(full_covs, MIN_COUNT_THRESH)
        elif args.mme:
            test_lambda = mme_lambda(full_covs)
        elif args.nb:  # _ContainArgs defines `nb`; the previous `bin` attribute does not exist
            test_lambda = binary_search_lambda(full_covs)
        elif args.mle:
            test_lambda = mle_zip(full_covs, genome_items)
        else:
            test_lambda = ratio_lambda(full_covs, MIN_COUNT_THRESH)

        if test_lambda is None:
            # No usable lambda estimate: the naive ANI is reported below.
            return_lambda = ADJUST_STATUS_LOW
            opt_lambda = None
        else:
            return_lambda = AdjustStatusLambda(value=test_lambda)
            opt_lambda = test_lambda

    opt_est_ani = ani_from_lambda(opt_lambda, mean_cov, ksize, full_covs)
    if opt_lambda is None or opt_est_ani is None or no_adj:
        final_est_ani = naive_ani
    else:
        final_est_ani = opt_est_ani

    # NOTE: sylph's winner_map k-mer reassignment / minimum-ANI filtering step is not
    # implemented here; `winner_map` and `kmers_lost_count` are module-level placeholders.

    low_ani = high_ani = low_lambda = high_lambda = None
    if args.ci_int and opt_lambda is not None:
        bootstrap = bootstrap_interval(full_covs, ksize, args)
        low_ani, high_ani = bootstrap[0], bootstrap[1]
        low_lambda, high_lambda = bootstrap[2], bootstrap[3]

    seq_name = sample_sig.name or sample_sig.filename
    kmers_lost = kmers_lost_count if winner_map is not None else None

    result = AniResult(
        naive_ani=naive_ani,
        final_est_ani=final_est_ani,
        final_est_cov=opt_lambda,
        seq_name=seq_name,
        gn_name=genome_sig.filename,
        contig_name=genome_sig.name,
        mean_cov=geq1_mean_cov,
        median_cov=median_cov,
        containment_index=(contain_count, len(genome_hashes)),
        lambda_status=return_lambda,
        ani_ci=(low_ani, high_ani),
        lambda_ci=(low_lambda, high_lambda),
        genome_sketch=genome_sig,
        rel_abund=None,
        seq_abund=None,
        kmers_lost=kmers_lost,
    )

    # NOTE(review): `contig_name` is set on the AniResult but is not listed here, so it is not
    # part of the returned dataframe -- confirm whether that is intended.
    columns_ani = [
        "naive_ani",  # the ani according to naive calculations
        "final_est_ani",  # final estimated ani
        "final_est_cov",  # the final estimated coverage
        "seq_name",  # the name of the sequence
        "gn_name",  # the name of the genome
        "mean_cov",  # the mean coverage observed
        "median_cov",  # the median coverage observed
        "containment_index",  # the containment index
        "lambda_status",  # lambda status
        "ani_ci",  # ani confidence interval
        "lambda_ci",  # lambda confidence interval
        "genome_sketch",  # genome sourmash signature
        "rel_abund",  # the taxonomic abundance observed (set to None for now)
        "seq_abund",  # the number of sequences that match a given genome (set to None for now)
        "kmers_lost",  # k-mers reassigned during the tax triage step (set to None for now)
    ]

    return pd.DataFrame([result], columns=columns_ani)
def get_organisms_with_nonzero_overlap(
    manifest: pd.DataFrame,
    sample_file: str,
    scale: int,
    ksize: int,
    num_threads: int,
    path_to_genome_temp_dir: str,
    path_to_sample_temp_dir: str,
) -> List[str]:
    """
    This function runs the sourmash multisearch to find the organisms that have non-zero overlap with the sample.
    :param manifest: a dataframe with the following columns:
                        'organism_name',
                        'md5sum',
                        'num_unique_kmers_in_genome_sketch',
                        'num_total_kmers_in_genome_sketch',
                        'genome_scale_factor',
                        'num_exclusive_kmers_in_sample_sketch',
                        'num_total_kmers_in_sample_sketch',
                        'sample_scale_factor',
                        'min_coverage'
    :param sample_file: string (path to the sample signature zip file)
    :param scale: int (scale factor)
    :param ksize: int (size of kmer)
    :param num_threads: int (number of threads to use for parallelization)
    :param path_to_genome_temp_dir: string (path to the genome temporary directory generated by the training step)
    :param path_to_sample_temp_dir: string (path to the sample temporary directory)
    :return: a list of organism names that have non-zero overlap with the sample
    :raises ValueError: if the sourmash multisearch command cannot be started or exits with a non-zero status
    """
    import subprocess

    # unzip the sourmash signature file to the temporary directory
    logger.info("Unzipping the sample signature zip file")
    with zipfile.ZipFile(sample_file, "r") as sample_zip_file:
        sample_zip_file.extractall(path_to_sample_temp_dir)
    all_gz_files = glob.glob(f"{path_to_sample_temp_dir}/signatures/*.sig.gz")
    # decompress all signature files
    logger.info(f"Decompressing {len(all_gz_files)} .sig.gz files using {num_threads} threads.")
    decompress_all_sig_files(all_gz_files, num_threads)

    # multisearch takes two text files listing signature paths: one for the sample ...
    sample_sig_dir = os.path.join(path_to_sample_temp_dir, "signatures")
    sample_sig_file = pd.DataFrame(
        [os.path.join(sample_sig_dir, sig_file) for sig_file in os.listdir(sample_sig_dir)]
    )
    sample_sig_file_path = os.path.join(path_to_sample_temp_dir, "sample_sig_file.txt")
    sample_sig_file.to_csv(sample_sig_file_path, header=False, index=False)

    # ... and one for every organism in the manifest
    organism_sig_file = pd.DataFrame(
        [
            os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX)
            for md5sum in manifest["md5sum"]
        ]
    )
    organism_sig_file_path = os.path.join(path_to_sample_temp_dir, "organism_sig_file.txt")
    organism_sig_file.to_csv(organism_sig_file_path, header=False, index=False)

    multisearch_result_file = os.path.join(path_to_sample_temp_dir, "sample_multisearch_result.csv")

    # run the sourmash multisearch; an argument list (no shell) keeps paths containing spaces
    # or shell metacharacters from being split or interpreted
    cmd_args = [
        "sourmash", "scripts", "multisearch",
        sample_sig_file_path,
        organism_sig_file_path,
        "-s", str(scale),
        "-k", str(ksize),
        "-c", str(num_threads),
        "-t", "0",
        "-o", multisearch_result_file,
    ]
    cmd = " ".join(cmd_args)
    logger.info(f"Running sourmash multisearch with command: {cmd}")
    try:
        completed = subprocess.run(cmd_args, check=False)
    except OSError as err:
        # e.g. sourmash is not installed; keep the exception type callers already handle
        raise ValueError(f"Error running sourmash multisearch with command: {cmd}") from err
    if completed.returncode != 0:
        raise ValueError(f"Error running sourmash multisearch with command: {cmd}")

    # read the multisearch result, only if the file is not empty
    try:
        multisearch_result = pd.read_csv(
            multisearch_result_file,
            sep=",",
            header=0,
        )
    except pd.errors.EmptyDataError:
        print('ERROR: Multisearch file is empty. Likely there are no microorganisms in your sample, or something went wrong', flush=True)
        # exit status 0 is kept from the original behaviour; sys.exit does not depend on the
        # `site` module the way the interactive `exit` helper does
        sys.exit(0)

    multisearch_result = multisearch_result.drop_duplicates().reset_index(drop=True)

    return multisearch_result["match_name"].to_list()
def get_exclusive_hashes(
    manifest: pd.DataFrame,
    nontrivial_organism_names: List[str],
    sample_sig: sourmash.SourmashSignature,
    ksize: int,
    path_to_genome_temp_dir: str,
) -> Tuple[List[Tuple[int, int]], pd.DataFrame, pd.DataFrame]:
    """
    This function gets the unique hashes exclusive to each of the organisms that have non-zero overlap with the
    sample, finds how many of them are in the sample, and runs cov_calc for each of those organisms.
    :param manifest: a dataframe with the following columns:
                        'organism_name',
                        'md5sum',
                        'num_unique_kmers_in_genome_sketch',
                        'num_total_kmers_in_genome_sketch',
                        'genome_scale_factor',
                        'num_exclusive_kmers_in_sample_sketch',
                        'num_total_kmers_in_sample_sketch',
                        'sample_scale_factor',
                        'min_coverage'
    :param nontrivial_organism_names: a list of organism names that have non-zero overlap with the sample
    :param sample_sig: the sample signature
    :param ksize: int (size of kmer)
    :param path_to_genome_temp_dir: string (path to the genome temporary directory generated by the training step)
    :return: a 3-tuple of
        1. a list of tuples, one per organism, each containing
            a. the number of unique hashes exclusive to the organism under consideration
            b. the number of those exclusive hashes that are in the sample
        2. a new manifest dataframe that only contains the organisms that have non-zero overlap with the sample
        3. a dataframe of the concatenated cov_calc results (organisms for which cov_calc returned None are
           absent; empty if there are no results at all)
    """

    def _signature_path(md5sum: str) -> str:
        # location of an organism's signature inside the genome temporary directory
        return os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX)

    # get manifest information for the organisms that have non-zero overlap with the sample
    sub_manifest = manifest.loc[
        manifest["organism_name"].isin(nontrivial_organism_names), :
    ].reset_index(drop=True)
    organism_md5sum_list = sub_manifest["md5sum"].to_list()

    # First pass: split all hashes into "seen in exactly one organism" and "seen in several".
    single_occurrence_hashes: Set[int] = set()
    multiple_occurrence_hashes: Set[int] = set()
    for md5sum in tqdm(organism_md5sum_list, desc="Processing organism signatures"):
        sig = load_signature_with_ksize(_signature_path(md5sum), ksize)
        for kmer_hash in sig.minhash.hashes:
            if kmer_hash in multiple_occurrence_hashes:
                continue
            if kmer_hash in single_occurrence_hashes:
                single_occurrence_hashes.remove(kmer_hash)
                multiple_occurrence_hashes.add(kmer_hash)
            else:
                single_occurrence_hashes.add(kmer_hash)

    del multiple_occurrence_hashes  # free up memory

    # Second pass: each signature is loaded once and used both for the exclusive-hash lookup
    # and for cov_calc (effective coverage according to Shaw and Yu (2024)).
    logger.info("Finding hashes that are unique to each organism")
    exclusive_hashes_org = []
    stats_list = []
    for md5sum in tqdm(
        organism_md5sum_list, desc="Finding exclusive hashes and coverage per organism"
    ):
        sig = load_signature_with_ksize(_signature_path(md5sum), ksize)
        exclusive_hashes_org.append(
            {kmer_hash for kmer_hash in sig.minhash.hashes if kmer_hash in single_occurrence_hashes}
        )
        stats_out = cov_calc(sample_sig, sig)
        # cov_calc returns None when the genome shares no k-mer with the sample
        if stats_out is not None:
            stats_list.append(stats_out)

    del single_occurrence_hashes  # free up memory

    columns_of_interest = [
        'naive_ani',
        'final_est_ani',
        'final_est_cov',
        'mean_cov',
        'median_cov',
        'lambda_ci',
        'ani_ci'
    ]

    # pd.concat raises on an empty list, so fall back to an empty frame in that case
    if stats_list:
        final_stats_df = pd.concat(stats_list, ignore_index=True)
    else:
        final_stats_df = pd.DataFrame(columns=columns_of_interest)

    del stats_list  # free up memory

    # Find hashes that are unique to each organism and in the sample
    logger.info("Finding hashes that are unique to each organism and in the sample")
    sample_hashes = set(sample_sig.minhash.hashes)
    exclusive_hashes_info = [
        (len(exclusive_hashes), len(exclusive_hashes.intersection(sample_hashes)))
        for exclusive_hashes in tqdm(
            exclusive_hashes_org, desc="Matching exclusive hashes with sample"
        )
    ]

    # Report the coverage statistics through the logger instead of bare print() calls
    logger.debug("Per-organism coverage statistics:\n{}", final_stats_df)
    if not final_stats_df.empty:
        summary_stats = final_stats_df[columns_of_interest].describe()
        logger.info("Summary of coverage statistics:\n{}", summary_stats)

    return exclusive_hashes_info, sub_manifest, final_stats_df
def get_alt_mut_rate(
    nu: int, thresh: float, ksize: int, significance: float = 0.99
) -> float:
    """
    Computes the alternative mutation rate for a given significance level. I.e. how much higher would the mutation rate
    have needed to be in order to have a false positive rate of significance (since we are setting the false negative
    rate to significance by design)?
    :param nu: int (Number of k-mers exclusive to the organism under consideration)
    :param thresh: Number of exclusive k-mers I would need to observe in order to reject the null hypothesis (i.e.
    accept that the organism is present); callers pass the float returned by binom.ppf
    :param ksize: int (k-mer size)
    :param significance: value between 0 and 1 expressing the desired false positive rate (and by design, the false
    negative rate)
    :return: float (alternative mutation rate; how much higher would the mutation rate have needed to be in order to
    make FP and FN rates equal to significance), or -1.0 when the computation yields NaN
    """
    # Closed form that replaces a binary search: invert the regularized incomplete *Beta*
    # function, i.e. Solve[significance ==
    #   BetaRegularized[1 - (1 - mutCurr)^k, nu - thresh, 1 + thresh], mutCurr]
    # per mathematica
    mut = 1 - (1 - betaincinv(nu - thresh, 1 + thresh, significance)) ** (1 / ksize)
    # NaN (e.g. parameters outside betaincinv's domain) is reported with the sentinel -1.0
    return -1.0 if np.isnan(mut) else float(mut)
def single_hyp_test(
    exclusive_hashes_info_org: Tuple[int, int],
    ksize: int,
    significance: float = 0.99,
    ani_thresh: float = 0.95,
    min_coverage: float = 1,
) -> Tuple[bool, float, int, int, int, float, float, float]:
    """
    Performs a single hypothesis test for the presence of a genome in a metagenome.
    :param exclusive_hashes_info_org: a tuple containing the following information:
            1. the number of unique hashes exclusive to this genome under consideration
            2. the number of unique hashes exclusive to this genome under consideration that are in the sample
    :param ksize: int (k-mer size)
    :param significance: float (significance level for the hypothesis test)
    :param ani_thresh: threshold for ANI (i.e. how similar do the genomes need to be in order to be considered the same)
    :param min_coverage: minimum coverage of the genome under consideration in the metagenome (float in [0, 1])
    :return: a tuple of
            1. in_sample_est: whether the genome is estimated to be present (takes coverage into account)
            2. p_val: binomial CDF at the number of matches (1.0 if there are more matches than
               coverage-adjusted exclusive k-mers)
            3. num_exclusive_kmers: number of k-mers exclusive to the genome
            4. num_exclusive_kmers_coverage: int(num_exclusive_kmers * min_coverage)
            5. num_matches: number of exclusive k-mers that are present in the sample
            6. acceptance_threshold_with_coverage: number of matches needed to accept presence
            7. actual_confidence_with_coverage: 1 - binomial CDF at the acceptance threshold
            8. alt_confidence_mut_rate_with_coverage: result of get_alt_mut_rate for the adjusted counts
    """
    num_exclusive_kmers, num_matches = (
        exclusive_hashes_info_org[0],
        exclusive_hashes_info_org[1],
    )

    # probability that a k-mer carries no mutation when the ANI equals ani_thresh
    non_mut_p = ani_thresh ** ksize

    # number of unique k-mers I would see given a coverage of min_coverage
    num_exclusive_kmers_coverage = int(num_exclusive_kmers * min_coverage)

    # how many unique k-mers would I need to observe in order to reject the null hypothesis,
    # assuming coverage of min_cov?
    acceptance_threshold_with_coverage = binom.ppf(
        1 - significance, num_exclusive_kmers_coverage, non_mut_p
    )
    # what is the actual confidence of the test, assuming coverage of min_cov?
    actual_confidence_with_coverage = 1 - binom.cdf(
        acceptance_threshold_with_coverage, num_exclusive_kmers_coverage, non_mut_p
    )
    # what is the alternative mutation rate, assuming coverage of min_cov? I.e. how much higher
    # would the mutation rate (resp. how low of ANI) have needed to be in order to have a false
    # positive rate of significance?
    alt_confidence_mut_rate_with_coverage = get_alt_mut_rate(
        num_exclusive_kmers_coverage,
        acceptance_threshold_with_coverage,
        ksize,
        significance=significance,
    )

    # calculate the p-value considering the coverage
    if num_matches <= num_exclusive_kmers_coverage:
        p_val = binom.cdf(num_matches, num_exclusive_kmers_coverage, non_mut_p)
    else:
        p_val = 1.0

    # is the genome present? Takes coverage into account; zero matches never counts as present
    in_sample_est = bool(
        num_matches >= acceptance_threshold_with_coverage and num_matches != 0
    )

    return (
        in_sample_est,
        p_val,
        num_exclusive_kmers,
        num_exclusive_kmers_coverage,
        num_matches,
        acceptance_threshold_with_coverage,
        actual_confidence_with_coverage,
        alt_confidence_mut_rate_with_coverage,
    )
def hypothesis_recovery(
    manifest: pd.DataFrame,
    sample_info_set: Tuple[str, sourmash.SourmashSignature],
    path_to_genome_temp_dir: str,
    min_coverage_list: List[float],
    scale: int,
    ksize: int,
    significance: float = 0.99,
    ani_thresh: float = 0.95,
    num_threads: int = 16,
) -> Tuple[List[pd.DataFrame], pd.DataFrame]:
    """
    Go through each of the organisms that have non-zero overlap with the sample and perform a hypothesis test for the
    presence of that organism in the sample: have we seen enough k-mers exclusive to that organism to conclude that
    an organism with ANI > ani_thresh (to the one under consideration) is present in the sample?
    :param manifest: a dataframe with the following columns:
                        'organism_name',
                        'md5sum',
                        'num_unique_kmers_in_genome_sketch',
                        'num_total_kmers_in_genome_sketch',
                        'genome_scale_factor',
                        'num_exclusive_kmers_in_sample_sketch',
                        'num_total_kmers_in_sample_sketch',
                        'sample_scale_factor',
                        'min_coverage'
    :param sample_info_set: a tuple of (path to the sample signature file, sample signature object)
    :param path_to_genome_temp_dir: path to the genome temporary directory generated by the training step
    :param min_coverage_list: a list of minimum coverage values; the full set of tests is run once per value
    :param scale: scale factor
    :param ksize: k-mer size
    :param significance: significance level for the hypothesis test
    :param ani_thresh: threshold for ANI (i.e. how similar do the genomes need to be in order to be considered the same)
    :param num_threads: number of threads to use for parallelization
    :return: a tuple of
        1. a list of pandas dataframes (one per min_coverage value) with the results of the hypothesis tests
        2. the cov_calc statistics dataframe returned by get_exclusive_hashes
    """
    import shutil

    # unpack the sample info set
    sample_file, sample_sig = sample_info_set

    # create a temporary directory for the sample
    sample_dir = os.path.dirname(sample_file)
    sample_name = os.path.basename(sample_file).replace(".sig.zip", "")
    path_to_sample_temp_dir = os.path.join(
        sample_dir, f"sample_{sample_name}_intermediate_files"
    )
    if os.path.exists(path_to_sample_temp_dir):
        # if exists, remove it (without a shell, so unusual characters in the path are safe)
        logger.info(f"Removing existing temporary directory: {path_to_sample_temp_dir}")
        if os.path.isdir(path_to_sample_temp_dir):
            shutil.rmtree(path_to_sample_temp_dir)
        else:
            os.remove(path_to_sample_temp_dir)
    os.makedirs(path_to_sample_temp_dir)

    # Find the organisms that have non-zero overlap with the sample
    nontrivial_organism_names = get_organisms_with_nonzero_overlap(
        manifest,
        sample_file,
        scale,
        ksize,
        num_threads,
        path_to_genome_temp_dir,
        path_to_sample_temp_dir,
    )

    # Get the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample
    # (from here on `manifest` only holds those organisms)
    exclusive_hashes_info, manifest, final_stats_df = get_exclusive_hashes(
        manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir
    )

    # Set up the results dataframe columns
    # n.b. that the output of cov_calc is not being incorporated here; instead it's being returned separately,
    # as a pandas dataframe.
    given_columns = [
        "in_sample_est",  # Main output: Boolean indicating whether genome is present in sample
        "p_vals",  # Probability of observing this or more extreme result at ANI threshold.
        "num_exclusive_kmers_to_genome",  # Number of k-mers exclusive to genome
        "num_exclusive_kmers_to_genome_coverage",  # Number of k-mers exclusive to genome, assuming coverage of min_cov
        "num_matches",  # Number of k-mers exclusive to genome that are present in the sample
        "acceptance_threshold_with_coverage",  # Acceptance threshold with adjusting for coverage
        "actual_confidence_with_coverage",  # Actual confidence with adjusting for coverage
        "alt_confidence_mut_rate_with_coverage",  # Mutation rate at which false positive rate = p_val,
        # accounting for the min_coverage parameter
    ]

    # Using multiprocessing.Pool to parallelize the execution; one pool serves every coverage value
    manifest_list = []
    with Pool(processes=num_threads) as pool:
        for min_coverage in tqdm(min_coverage_list, desc="Computing hypothesis recovery"):
            logger.info(f"Computing hypothesis recovery for min_coverage={min_coverage}")
            params = [
                (info, ksize, significance, ani_thresh, min_coverage)
                for info in exclusive_hashes_info
            ]
            results = pool.starmap(single_hyp_test, params)
            logger.info(f"Finished computing all results for min_coverage value: {min_coverage}")

            # Create a pandas dataframe to store the results
            results = pd.DataFrame(results, columns=given_columns)

            # combine the results with the manifest
            manifest_list.append(
                pd.concat([manifest.assign(min_coverage=min_coverage), results], axis=1)
            )

    return manifest_list, final_stats_df
# Driver that reproduces the hypothesis_recovery_src workflow on the local demo data so that
# the sylph coverage model can be exercised end to end.
# NOTE: the paths below are machine-specific; adjust them before running elsewhere.
manifestFilePath="/Users/randolph.raborn/Desktop/YACHT/demo/query_data/gtdb-rs214-reps.k31_0.95_pretrained/gtdb-rs214-reps.k31_0.95_processed_manifest.tsv"
sampleSigPath="/Users/randolph.raborn/Desktop/YACHT/demo/new_demo_files/SRR6940089_sample_out.sig.zip"
genTempDir="/Users/randolph.raborn/Desktop/YACHT/demo/query_data/gtdb_ani_thresh_0.95_intermediate_files/"
sampleTempDir="/Users/randolph.raborn/Desktop/YACHT/demo/new_demo_files/genome_reference.sig_temp"
multisearch_result_file = "/Users/randolph.raborn/Desktop/YACHT/demo/query_data/gtdb-rs214-reps.k31_0.95_pretrained/gtdb-rs214-reps.k31_0.95_intermediate_files/training_multisearch_result.csv"
multisearchResultPath= "/Users/randolph.raborn/Desktop/YACHT/demo/new_demo_files/genome_reference.sig_temp/sample_multisearch_result.csv"
ksize=31
min_cov=0.10 #setting to 0.10


def main() -> None:
    """Load the demo sample and manifest, run hypothesis_recovery once and print its output."""
    sample_sig = load_signature_with_ksize(sampleSigPath, ksize)
    print(sample_sig)
    sample_info_set = (sampleSigPath, sample_sig)

    # importing the manifest
    manifest = pd.read_csv(manifestFilePath, sep="\t", header=0)

    print(f"Made it here")

    # hypothesis_recovery runs the complete set of tests once per entry of this list, so it
    # holds each distinct coverage value exactly once (not one copy per manifest row)
    min_cov_list = [min_cov]

    hyp_rec_out = hypothesis_recovery(
        manifest=manifest,
        sample_info_set=sample_info_set,
        path_to_genome_temp_dir=genTempDir,
        min_coverage_list=min_cov_list,
        scale=1000,
        ksize=ksize,
        ani_thresh=0.85,
        num_threads=16,
    )

    print(hyp_rec_out)


if __name__ == "__main__":
    # the start method is chosen only when run as a script, never on import
    multiprocessing.set_start_method('fork')
    main()
# Python stand-in for sylph's Rust enum `AdjustStatus`: one small class per variant.
# Every variant is a frozen dataclass so that all of them behave alike: immutable,
# hashable, compared by value (two AdjustStatusHigh() objects are equal) and printed
# with a readable repr. Variants of different classes never compare equal.
@dataclass(frozen=True)
class AdjustStatusLambda:
    """Variant that carries a numeric payload in ``value``."""
    value: float


@dataclass(frozen=True)
class AdjustStatusHigh:
    """Payload-free variant."""


@dataclass(frozen=True)
class AdjustStatusLow:
    """Payload-free variant."""


@dataclass(frozen=True)
class AdjustStatusNone:
    """Payload-free variant."""
def get_num_kmers(
    minhash_mean_abundance: Optional[float],
    minhash_hashes_len: int,
    minhash_scaled: int,
    scale: bool = True,
) -> int:
    """
    Estimate the total number of kmers represented by a sketch.
    :param minhash_mean_abundance: float or None (mean abundance of the signature; a falsy value means abundances were not kept)
    :param minhash_hashes_len: int (number of hashes in the signature)
    :param minhash_scaled: int (scale factor of the signature)
    :param scale: bool (whether to multiply the count by the scale factor)
    :return: int (estimated total number of kmers, rounded with np.round)
    """
    # Without abundance information every hash counts once.
    per_hash_weight = minhash_mean_abundance if minhash_mean_abundance else 1
    scale_multiplier = minhash_scaled if scale else 1
    estimate = per_hash_weight * minhash_hashes_len * scale_multiplier
    return int(np.round(estimate))
def get_info_from_single_sig(
    sig_file: str, ksize: int
) -> Optional[Tuple[str, str, str, float, int, int]]:
    """
    Helper function that gets signature information (raw file path, name, md5sum, minhash mean abundance, minhash_hashes_len, minhash scaled) from a single sourmash signature file.
    :param sig_file: string (location of the signature file)
    :param ksize: int (size of kmer)
    :return: tuple (raw file path, name, md5sum, minhash mean abundance, minhash_hashes_len, minhash scaled),
             or None when the information cannot be extracted (a warning is logged)
    """
    try:
        sig = load_signature_with_ksize(sig_file, ksize)
        return (
            sig_file,
            sig.name,
            sig.md5sum(),
            sig.minhash.mean_abundance,
            len(sig.minhash.hashes),
            sig.minhash.scaled,
        )
    # Best effort: one unreadable signature must not abort the whole collection.
    # `Exception` (not a bare except) lets KeyboardInterrupt / SystemExit through,
    # and the reason is logged so the skipped file can be diagnosed.
    except Exception as err:
        logger.warning(
            f"CANNOT extract the relevant info from the signature file: {sig_file} ({type(err).__name__}: {err})"
        )
        return None
+ :param num_threads: int (number of threads to use) + :param ani_thresh: float (threshold for ANI, below which we consider two organisms to be distinct) + :param ksize: int (size of kmer) + :param path_to_temp_dir: string (path to the folder to store the intermediate files) + :return: a dataframe containing the selected reference signature information + """ + + # run Mahmudur's cpp for genome comparison + # save signature files to a text file + sig_files = pd.DataFrame( + [ + os.path.join(path_to_temp_dir, "signatures", file) + for file in os.listdir(os.path.join(path_to_temp_dir, "signatures")) + ] + ) + sig_files_path = os.path.join(path_to_temp_dir, "training_sig_files.tsv") + sig_files.to_csv(sig_files_path, header=False, index=False) + + # convert ani threshold to containment threshold + containment_thresh = ani_thresh**ksize + total_sig_files = len(sig_files) + if total_sig_files <= num_genome_threshold: + passes = 1 + else: + passes = int(total_sig_files / num_genome_threshold) + 1 + cmd = f"{FILE_LOCATION}/run_yacht_train_core -t {num_threads} -c {containment_thresh} -p {passes} {sig_files_path} {path_to_temp_dir} {os.path.join(path_to_temp_dir, 'selected_result.tsv')}" + logger.info(f"Running comparison algorithm with command: {cmd}") + exit_code = os.system(cmd) + if exit_code != 0: + raise ValueError(f"Error running comparison algorithm with command: {cmd}") + + # move all split comparison files to a single foldr + os.makedirs(os.path.join(path_to_temp_dir, "comparison_files"), exist_ok=True) + for file in glob(os.path.join(path_to_temp_dir, "*.txt")): + shutil.move(file, os.path.join(path_to_temp_dir, "comparison_files")) + + # get info from the signature files of selected genomes + selected_sig_files = pd.read_csv(os.path.join(path_to_temp_dir, 'selected_result.tsv'), sep="\t", header=None) + selected_sig_files = selected_sig_files[0].to_list() + + # get the mapping from signature file name to genome name + mapping = {sig_info_dict[name][-1]:name for 
name in sig_info_dict} + selected_genome_names_set = set([mapping[sig_file_path] for sig_file_path in selected_sig_files]) + + # remove the close related organisms from the reference genome list + manifest_df = [] + for sig_name, ( + md5sum, + minhash_mean_abundance, + minhash_hashes_len, + minhash_scaled, + _ + ) in tqdm(sig_info_dict.items(), desc="Removing close related organisms from the reference genome list"): + if sig_name in selected_genome_names_set: + manifest_df.append( + ( + sig_name, + md5sum, + minhash_hashes_len, + get_num_kmers( + minhash_mean_abundance, + minhash_hashes_len, + minhash_scaled, + False, + ), + minhash_scaled, + ) + ) + manifest_df = pd.DataFrame( + manifest_df, + columns=[ + "organism_name", + "md5sum", + "num_unique_kmers_in_genome_sketch", + "num_total_kmers_in_genome_sketch", + "genome_scale_factor", + ], + ) + + return manifest_df + + + +def collect_signature_info( + num_threads: int, ksize: int, path_to_temp_dir: str +) -> Dict[str, Tuple[str, float, int, int]]: + """ + Helper function that collects signature information (raw file path, name, md5sum, minhash mean abundance, minhash_hashes_len, minhash scaled) from a sourmash signature database. 
def get_column_indices(
    column_name_to_index: Dict[str, int]
) -> Tuple[int, int, int, int, Optional[int]]:
    """
    (thanks to https://github.com/CAMI-challenge/OPAL, this function is modified get_column_indices from its load_data.py)
    Look up the positions of the CAMI profile columns.
    TAXID, RANK, PERCENTAGE and TAXPATH are mandatory; TAXPATHSN is optional.
    :param column_name_to_index: dictionary mapping column name to column index
    :return: indices in the order RANK, TAXID, PERCENTAGE, TAXPATH, TAXPATHSN (TAXPATHSN is None when the column is absent)
    :raises RuntimeError: when a mandatory column is missing (the first missing one is logged)
    """
    # Checked in this fixed order so the first missing column is the one reported.
    for mandatory in ("TAXID", "RANK", "PERCENTAGE", "TAXPATH"):
        if mandatory not in column_name_to_index:
            logger.error(COL_NOT_FOUND_ERROR.format(mandatory))
            raise RuntimeError

    taxpathsn_position: Optional[int] = (
        column_name_to_index["TAXPATHSN"]
        if "TAXPATHSN" in column_name_to_index
        else None
    )
    return (
        column_name_to_index["RANK"],
        column_name_to_index["TAXID"],
        column_name_to_index["PERCENTAGE"],
        column_name_to_index["TAXPATH"],
        taxpathsn_position,
    )
+ params:cami_content: list of strings (lines of the CAMI profile file) + return: list of tuples (sample_id, header, profile) + """ + header = {} + column_name_to_index = {} + profile = [] + samples_list = [] + predictions_dict = {} + reading_data = False + got_column_indices = False + + for line in cami_content: + if len(line.strip()) == 0 or line.startswith("#"): + continue + line = line.rstrip("\n") + + # parse header with column indices + if line.startswith("@@"): + for index, column_name in enumerate(line[2:].split("\t")): + column_name_to_index[column_name] = index + ( + index_rank, + index_taxid, + index_percentage, + index_taxpath, + index_taxpathsn, + ) = get_column_indices(column_name_to_index) + got_column_indices = True + reading_data = False + continue + + # parse header with metadata + if line.startswith("@"): + # if last line contained sample data and new header starts, store profile for sample + if reading_data: + if "SAMPLEID" in header and "VERSION" in header and "RANKS" in header: + if len(profile) > 0: + samples_list.append((header["SAMPLEID"], header, profile)) + profile = [] + predictions_dict = {} + else: + logger.error( + "Header is incomplete. Check if the header of each sample contains at least SAMPLEID, VERSION, and RANKS." + ) + raise RuntimeError + header = {} + reading_data = False + got_column_indices = False + key, value = line[1:].split(":", 1) + header[key.upper()] = value.strip() + continue + + if not got_column_indices: + logger.error( + "Header line starting with @@ is missing or at wrong position." 
def create_output_folder(outfolder):
    """
    Helper function that creates the output folder if it does not exist.
    :param outfolder: location of output folder
    :return: None
    """
    if not os.path.exists(outfolder):
        logger.info(f"Creating output folder: {outfolder}")
        # exist_ok=True: another process may create the folder between the check
        # above and this call; that must not raise FileExistsError.
        os.makedirs(outfolder, exist_ok=True)
def _decompress_and_remove(file_path: str) -> None:
    """
    Decompresses a GZIP-compressed file and removes the original compressed file.
    The output is written next to the input, with the last extension (e.g. ".gz") stripped.
    Failures are logged and swallowed so that one bad file does not stop a batch;
    in that case the compressed file is left in place.
    :param file_path: The path to the .sig.gz file that needs to be decompressed and deleted.
    :return: None
    """
    try:
        output_filename = os.path.splitext(file_path)[0]
        with gzip.open(file_path, 'rb') as f_in, open(output_filename, 'wb') as f_out:
            # Stream in chunks instead of holding the whole decompressed file in memory.
            shutil.copyfileobj(f_in, f_out)

        # Only reached when decompression completed without error.
        os.remove(file_path)

    except Exception as e:
        logger.warning(f"Failed to process {file_path}: {e}")
def newton_raphson(ratio: float, mean: float):
    """
    Shaw and Yu (2024)'s implementation of Newton-Raphson, used to assist in the calculation of lambda.

    Iterates x <- x - f(x) / f'(x) with
        f(x)  = (1 - ratio) * x - mean * (1 - exp(-x))
        f'(x) = (1 - ratio) - mean * exp(-x)
    starting from x = mean / (1 - ratio), for at most 1000 iterations.
    :param ratio: float (mle_zip passes the fraction of zero entries)
    :param mean: float (mle_zip passes the mean over all entries)
    :return: float (the final iterate; NaN when a zero denominator or an overflow is hit)
    """
    try:
        curr = mean / (1 - ratio)
        for _ in range(1000):
            e_curr = math.exp(-curr)
            f_val = (1 - ratio) * curr - mean * (1 - e_curr)
            f_prime = (1 - ratio) - mean * e_curr
            step = f_val / f_prime
            if step == 0:
                # Exactly converged: further iterations would not change curr.
                break
            curr = curr - step
    except (ZeroDivisionError, OverflowError):
        return float("nan")
    return curr


def mle_zip(full_covs: list[int], _k: float):
    """
    Maximum likelihood estimator for the zero-inflated Poisson (ZIP) distribution from Shaw and Yu (2024).
    :param full_covs: list of ints (coverage value per k-mer, zeros included)
    :param _k: float (unused)
    :return: float lambda, or None when there is too little information (a single distinct
             non-zero value, or fewer than SAMPLE_SIZE_CUTOFF non-zero values) or the
             estimate is negative / NaN
    """
    nonzero_counts = Counter(x for x in full_covs if x != 0)
    n_nonzero = sum(nonzero_counts.values())
    n_zero = len(full_covs) - n_nonzero

    if len(nonzero_counts) == 1:  # If no info for inference, returns None
        return None

    if n_nonzero < SAMPLE_SIZE_CUTOFF:
        return None

    mean = float(np.mean(full_covs))
    zero_fraction = n_zero / len(full_covs)
    lambda_out = newton_raphson(zero_fraction, mean)

    if lambda_out < 0 or math.isnan(lambda_out):
        return None
    return lambda_out


def variance(data: list[int]):
    """
    An internal function that calculates the (population) variance: the average of the
    squared differences of all values from the mean.
    :param data: list of numbers
    :return: float, or None for empty input
    """
    if len(data) == 0:
        return None
    mean = np.mean(data)
    squared_dev = sum((float(x) - mean) * (float(x) - mean) for x in data)
    return float(squared_dev / len(data))


def ratio_lambda(full_covs: list[int], min_count_correct):
    """
    Estimating lambda according to Shaw and Yu (2024) using the ratio method.

    For a Poisson variable P(k + 1) / P(k) = lambda / (k + 1), hence
    lambda = count(k + 1) / count(k) * (k + 1), where k is the most frequent
    non-zero coverage value (ties go to the larger value).
    :param full_covs: list of ints (coverage value per k-mer, zeros included)
    :param min_count_correct: minimum number of observations required for both k and k + 1
    :return: float lambda, or None when it cannot be estimated (a single distinct non-zero
             value, fewer than SAMPLE_SIZE_CUTOFF non-zero values, k + 1 never observed,
             or too few observations of k or k + 1)
    """
    count_map = Counter(x for x in full_covs if x != 0)
    if len(count_map) == 1:  # Absent info for inference, returns None.
        return None

    if sum(count_map.values()) < SAMPLE_SIZE_CUTOFF:
        return None

    # Largest (count, value) pair == first element of the reverse-sorted list.
    _, most_ind = max((count, cov) for cov, count in count_map.items())
    if (most_ind + 1) not in count_map:
        return None

    count_p1 = count_map[most_ind + 1]
    count = count_map[most_ind]
    if count_p1 < min_count_correct or count < min_count_correct:
        return None

    return count_p1 / count * (most_ind + 1)


def r_moments_lambda(m: float, v: float, lambda_out: float):
    """
    Internal function used in the calculation of lambda using ratio from moments.
    :param m: float (mean of the coverage values)
    :param v: float (variance of the coverage values)
    :param lambda_out: float (candidate lambda)
    :return: float (the value used as `r` by ratio_calc)
    """
    # NOTE(review): formula kept exactly as written; confirm it against the
    # corresponding function in sylph's inference.rs.
    result = lambda_out / (v - 1 + lambda_out + m)
    return result


def ratio_calc(val: float, r: float, lambda_out: float):
    """
    Internal function used in the calculation of lambda using ratio from moments.
    :param val: float
    :param r: float (as returned by r_moments_lambda)
    :param lambda_out: float (candidate lambda)
    :return: float
    """
    if r < 100:
        return math.gamma(r + val + 1) / (val + 1) / math.gamma(r + val) * lambda_out / (r + lambda_out)
    # For large r the gamma functions are not evaluated.
    return (r + val + 1) / (val + 1) * lambda_out / (r + lambda_out)


def ratio_from_moments_lambda(val: float, lambda_out: float, m: float, v: float):
    """
    Function that calculates lambda using the ratio from moments formula.
    :param val: float
    :param lambda_out: float (candidate lambda)
    :param m: float (mean of the coverage values)
    :param v: float (variance of the coverage values)
    :return: float, or None when r is negative or the ratio cannot be evaluated
    """
    try:
        r = r_moments_lambda(m, v, lambda_out)
        if r < 0:
            return None
        return ratio_calc(val, r, lambda_out)
    except (ZeroDivisionError, ValueError, OverflowError):
        # Zero denominator, gamma pole or gamma overflow: no usable ratio for this candidate.
        return None


def mme_lambda(full_covs: list[int]) -> Optional[float]:
    """
    Calculates the "method of moments" estimator for the lambda parameter
    from a list of coverage values: lambda = variance / mean + mean - 1.
    :param full_covs: list of ints (coverage value per k-mer, zeros included)
    :return: float lambda, or None when there is a single distinct non-zero value, fewer
             than SAMPLE_SIZE_CUTOFF non-zero values, or the estimate is negative
    """
    distinct_nonzero = {x for x in full_covs if x != 0}
    num_zero = sum(1 for x in full_covs if x == 0)

    if len(distinct_nonzero) == 1:
        return None
    # Does the number of non-zero observations meet the cutoff threshold?
    if len(full_covs) - num_zero < SAMPLE_SIZE_CUTOFF:
        return None

    mean_val = float(np.mean(full_covs))
    variance_val = variance(full_covs)
    lambda_val = variance_val / mean_val + mean_val - 1.0

    # Ensuring lambda is non-negative
    if lambda_val < 0.0:
        return None
    return lambda_val


def binary_search_lambda(full_covs: list[int]):
    """
    Search for the lambda whose moment-based ratio (ratio_from_moments_lambda with val=1)
    is closest to the observed ratio (#values equal to 2) / (#values equal to 1).

    Despite the name this is a grid search: 10000 evenly spaced candidates starting at
    max(0.003, mean - 2) and ending just below mean + 5.
    :param full_covs: list of ints (coverage value per k-mer, zeros included)
    :return: float (best candidate), or None for empty input, when no value equals 1
             (the observed ratio is undefined), or when no candidate comes within 10000
             of the observed ratio
    """
    if len(full_covs) == 0:
        return None
    m = float(np.mean(full_covs))
    v = variance(full_covs)

    ones = sum(1 for x in full_covs if x == 1)
    twos = sum(1 for x in full_covs if x == 2)
    if ones == 0:
        return None
    ratio_est = twos / ones

    left = max(0.003, m - 2)
    right = m + 5
    n_steps = 10000
    best = None
    best_val = 10000
    for i in range(n_steps):
        test = (right - left) / n_steps * i + left
        proposed = ratio_from_moments_lambda(1, test, m, v)
        if proposed is not None:
            p = abs(proposed - ratio_est)
            if p < best_val:
                best_val = p
                best = test

    return best
def ani_from_lambda(lambda_val, lam_mean, k_value, full_cov):
    """
    Calculates an adjusted ani value.

    The fraction of non-zero entries in full_cov is divided by 1 - exp(-lambda_val)
    (for a Poisson count with mean lambda_val, the probability of a non-zero count),
    and the k_value-th root of that adjusted index is returned.
    Args:
        lambda_val: An optional float used to calculate ani.
        lam_mean: A float value (unused).
        k_value: A float value used as the inverse exponent.
        full_cov: A list of integers to analyze for non-zero counts.

    Returns:
        An optional float representing the calculated adjusted index 'ani',
        or None if the input lambda is None, NaN or not positive, if full_cov is
        empty, or if 'ani' is negative or NaN.
    """
    if lambda_val is None or len(full_cov) == 0:
        return None
    # lambda <= 0 makes 1 - exp(-lambda) zero or negative: the division below would
    # raise ZeroDivisionError, or math.pow would raise ValueError on a negative base.
    if math.isnan(lambda_val) or lambda_val <= 0:
        return None

    contain_count = sum(1 for x in full_cov if x != 0)
    adj_index = contain_count / (1.0 - math.exp(-lambda_val)) / len(full_cov)

    ani = math.pow(adj_index, 1.0 / k_value)

    if ani < 0.0 or math.isnan(ani):
        return None
    return ani
get_organisms_with_nonzero_overlap( - manifest: pd.DataFrame, - sample_file: str, - scale: int, - ksize: int, - num_threads: int, - path_to_genome_temp_dir: str, - path_to_sample_temp_dir: str, -) -> List[str]: - """ - This function runs the sourmash multisearch to find the organisms that have non-zero overlap with the sample. - :param manifest: a dataframe with the following columns: - 'organism_name', - 'md5sum', - 'num_unique_kmers_in_genome_sketch', - 'num_total_kmers_in_genome_sketch', - 'genome_scale_factor', - 'num_exclusive_kmers_in_sample_sketch', - 'num_total_kmers_in_sample_sketch', - 'sample_scale_factor', - 'min_coverage' - :param sample_file: string (path to the sample signature file) - :param scale: int (scale factor) - :param ksize: string (size of kmer) - :param num_threads: int (number of threads to use for parallelization) - :param path_to_genome_temp_dir: string (path to the genome temporary directory generated by the training step) - :param path_to_sample_temp_dir: string (path to the sample temporary directory) - :return: a list of organism names that have non-zero overlap with the sample - """ - # run the sourmash multisearch - # prepare the input files for the sourmash multisearch - # unzip the sourmash signature file to the temporary directory - logger.info("Unzipping the sample signature zip file") - with zipfile.ZipFile(sample_file, "r") as sample_zip_file: - sample_zip_file.extractall(path_to_sample_temp_dir) - all_gz_files = glob.glob(f"{path_to_sample_temp_dir}/signatures/*.sig.gz") - # decompress all signature files - logger.info(f"Decompressing {len(all_gz_files)} .sig.gz files using {num_threads} threads.") - decompress_all_sig_files(all_gz_files, num_threads) - - sample_sig_file = pd.DataFrame( - [ - os.path.join(path_to_sample_temp_dir, "signatures", sig_file) - for sig_file in os.listdir( - os.path.join(path_to_sample_temp_dir, "signatures") - ) - ] - ) - sample_sig_file_path = os.path.join(path_to_sample_temp_dir, 
"sample_sig_file.txt") - sample_sig_file.to_csv(sample_sig_file_path, header=False, index=False) - - organism_sig_file = pd.DataFrame( - [ - os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX) - for md5sum in manifest["md5sum"] - ] - ) - organism_sig_file_path = os.path.join( - path_to_sample_temp_dir, "organism_sig_file.txt" - ) - organism_sig_file.to_csv(organism_sig_file_path, header=False, index=False) - - # run the sourmash multisearch - cmd = f"sourmash scripts multisearch {sample_sig_file_path} {organism_sig_file_path} -s {scale} -k {ksize} -c {num_threads} -t 0 -o {os.path.join(path_to_sample_temp_dir, 'sample_multisearch_result.csv')}" - logger.info(f"Running sourmash multisearch with command: {cmd}") - exit_code = os.system(cmd) - if exit_code != 0: - raise ValueError(f"Error running sourmash multisearch with command: {cmd}") - - # read the multisearch result, only if the file is not empty - multisearch_result_file = os.path.join(path_to_sample_temp_dir, "sample_multisearch_result.csv") - try: - multisearch_result = pd.read_csv( - multisearch_result_file, - sep=",", - header=0, - ) - except pd.errors.EmptyDataError: - print('ERROR: Multisearch file is empty. Likely there are no microorganisms in your sample, or something went wrong', flush=True) - exit(0) - - multisearch_result = multisearch_result.drop_duplicates().reset_index(drop=True) - - return multisearch_result["match_name"].to_list() - - -def get_exclusive_hashes( - manifest: pd.DataFrame, - nontrivial_organism_names: List[str], - sample_sig: sourmash.SourmashSignature, - ksize: int, - path_to_genome_temp_dir: str, -) -> Tuple[List[Tuple[int, int]], pd.DataFrame]: - """ - This function gets the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample, and - then find how many are in the sample. 
- :param manifest: a dataframe with the following columns: - 'organism_name', - 'md5sum', - 'num_unique_kmers_in_genome_sketch', - 'num_total_kmers_in_genome_sketch', - 'genome_scale_factor', - 'num_exclusive_kmers_in_sample_sketch', - 'num_total_kmers_in_sample_sketch', - 'sample_scale_factor', - 'min_coverage' - :param nontrivial_organism_names: a list of organism names that have non-zero overlap with the sample - :param sample_sig: the sample signature - :param ksize: int (size of kmer) - :param path_to_genome_temp_dir: string (path to the genome temporary directory generated by the training step) - :return: - a list of tuples, each tuple contains the following information: - 1. the number of unique hashes exclusive to the organism under consideration - 2. the number of unique hashes exclusive to the organism under consideration that are in the sample - a new manifest dataframe that only contains the organisms that have non-zero overlap with the sample - """ - pvalue_cutoff=0.9999999999 - min_count_thresh=3 #TODO: consider whether to change this value - - def __find_exclusive_hashes( - md5sum: str, - path_to_temp_dir: str, - ksize: int, - single_occurrence_hashes: Set[int], - ) -> Set[int]: - # load genome signature - sig = load_signature_with_ksize( - os.path.join(path_to_temp_dir, "signatures", md5sum + SIG_SUFFIX), ksize - ) - return {hash for hash in sig.minhash.hashes if hash in single_occurrence_hashes} - - # get manifest information for the organisms that have non-zero overlap with the sample - sub_manifest = manifest.loc[ - manifest["organism_name"].isin(nontrivial_organism_names), : - ].reset_index(drop=True) - organism_md5sum_list = sub_manifest["md5sum"].to_list() - - single_occurrence_hashes: Set[int] = set() # Corrected type annotation - multiple_occurrence_hashes: Set[int] = set() - for md5sum in tqdm(organism_md5sum_list, desc="Processing organism signatures"): - sig = load_signature_with_ksize( - os.path.join(path_to_genome_temp_dir, 
"signatures", md5sum + SIG_SUFFIX), - ksize, - ) - for hash in sig.minhash.hashes: - if hash in multiple_occurrence_hashes: - continue - elif hash in single_occurrence_hashes: - single_occurrence_hashes.remove(hash) - multiple_occurrence_hashes.add(hash) - else: - single_occurrence_hashes.add(hash) - - - #print(multiple_occurrence_hashes) - - del multiple_occurrence_hashes # free up memory - - # Find hashes that are unique to each organism - logger.info("Finding hashes that are unique to each organism") - exclusive_hashes_org = [] - for md5sum in tqdm(organism_md5sum_list, desc="Finding exclusive hashes"): - exclusive_hashes_org.append( - __find_exclusive_hashes( - md5sum, path_to_genome_temp_dir, ksize, single_occurrence_hashes - ) - ) - - #print(f"Single occurrence hashes") - #print(single_occurrence_hashes) #adding this for testing - del single_occurrence_hashes # free up memory - - # Get sample hashes - sample_hashes = set(sample_sig.minhash.hashes) - - # Get sample hashes keys - sample_hashes_keys = sample_sig.minhash.hashes.keys() - #print(sample_hashes_keys) - samp_kmers_items = sample_sig.minhash.hashes.items() - samp_dict = dict(samp_kmers_items) - - stats_list = [] - for md5sum in tqdm(organism_md5sum_list, desc="Processing coverage per organism"): - sig = load_signature_with_ksize( - os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX), - ksize, - ) - stats_out = cov_calc(sample_sig, sig) #location of cov_calc, which calculates effective coverage and other things according to Shaw and Yu (2024) - stats_list.append(stats_out) - - final_stats_df = pd.concat(stats_list, ignore_index=True) - - del stats_list # free up memory - - # Find hashes that are unique to each organism and in the sample - logger.info("Finding hashes that are unique to each organism and in the sample") - exclusive_hashes_info = [] - for i, exclusive_hashes in enumerate( - tqdm(exclusive_hashes_org, desc="Matching exclusive hashes with sample") - ): - 
#print(f"exclusive_hash", exclusive_hashes) - exclusive_hashes_info.append( - (len(exclusive_hashes), len(exclusive_hashes.intersection(sample_hashes))) - ) - - # Calculate lambda and other related coverage metrics for each organism in the sample - #logger.info("Calculate lambda for each organism in the sample") - #for i, lambda_stats in enumerate() - - #print(type(exclusive_hashes_info)) - #print(exclusive_hashes_info) - #print(type(sub_manifest)) - - columns_of_interest = [ - 'naive_ani', - 'final_est_ani', - 'final_est_cov', - 'mean_cov', - 'median_cov', - 'lambda_ci', - 'ani_ci' - ] - - # Select only those columns from the DataFrame - selected_data = final_stats_df[columns_of_interest] - - summary_stats = selected_data.describe() - - #print(sub_manifest) - print(final_stats_df) - print(summary_stats) - #print(final_stats_df['lambda_ci'].unique()) - #print(final_stats_df['ani_ci'].unique()) - - return exclusive_hashes_info, sub_manifest, final_stats_df - -def get_alt_mut_rate( - nu: int, thresh: int, ksize: int, significance: float = 0.99 -) -> float: - """ - Computes the alternative mutation rate for a given significance level. I.e. how much higher would the mutation rate - have needed to be in order to have a false positive rate of significance (since we are setting the false negative - rate to significance by design)? - :param nu: int (Number of k-mers exclusive to the organism under consideration) - :param thresh: Number of exclusive k-mers I would need to observe in order to reject the null hypothesis (i.e. 
- accept that the organism is present) - :param ksize: int (k-mer size) - :param significance: value between 0 and 1 expressing the desired false positive rate (and by design, the false - negative rate) - :return: float (alternative mutation rate; how much higher would the mutation rate have needed to be in order to - make FP and FN rates equal to significance) - """ - # Replace binary search with the regularized incomplete Gamma function inverse: Solve[significance == - # BetaRegularized[1 - (1 - mutCurr)^k, nu - thresh, - # 1 + thresh], mutCurr] - # per mathematica - mut = 1 - (1 - betaincinv(nu - thresh, 1 + thresh, significance)) ** (1 / ksize) - return -1.0 if np.isnan(mut) else mut - - -def single_hyp_test( - exclusive_hashes_info_org: Tuple[int, int], - ksize: int, - significance: float = 0.99, - ani_thresh: float = 0.95, - min_coverage: int = 1, -) -> Tuple[bool, float, int, int, int, int, float, float]: - """ - Performs a single hypothesis test for the presence of a genome in a metagenome. - :param exclusive_hashes_info_org: a tuple containing the following information: - 1. the number of unique hashes exclusive to this genome under consideration - 2. the number of unique hashes exclusive to this genome under consideration that are in the sample - :param ksize: int (k-mer size) - :param significance: float (significance level for the hypothesis test) - :param ani_thresh: threshold for ANI (i.e. 
how similar do the genomes need to be in order to be considered the same) - :param min_coverage: minimum coverage of the genome under consideration in the metagenome (float in [0, 1]) - :return: A whole bunch of stuff - """ - # get the number of unique k-mers - num_exclusive_kmers = exclusive_hashes_info_org[0] - #print(exclusive_hashes_info_org) ##printing the output of this to determine what the data structure looks like - # mutation rate - non_mut_p = (ani_thresh) ** ksize - # # assuming coverage of 1, how many unique k-mers would I need to observe in order to reject the null hypothesis? - # acceptance_threshold_wo_coverage = binom.ppf(1-significance, num_exclusive_kmers, non_mut_p) - # # what is the actual confidence of the test? - # actual_confidence_wo_coverage = 1-binom.cdf(acceptance_threshold_wo_coverage, num_exclusive_kmers, non_mut_p) - # number of unique k-mers I would see given a coverage of min_coverage - num_exclusive_kmers_coverage = int(num_exclusive_kmers * min_coverage) - # how many unique k-mers would I need to observe in order to reject the null hypothesis, - # assuming coverage of min_cov? - acceptance_threshold_with_coverage = binom.ppf( - 1 - significance, num_exclusive_kmers_coverage, non_mut_p - ) - # what is the actual confidence of the test, assuming coverage of min_cov? - actual_confidence_with_coverage = 1 - binom.cdf( - acceptance_threshold_with_coverage, num_exclusive_kmers_coverage, non_mut_p - ) - # # what is the alternative mutation rate? I.e. how much higher would the mutation rate (resp. how low of ANI) - # # have needed to be in order to have a false positive rate of significance - # # (since we are setting the false negative rate to significance by design)? 
- # alt_confidence_mut_rate = get_alt_mut_rate(num_exclusive_kmers, acceptance_threshold_wo_coverage, ksize, - # significance=significance) - # same as above, but assuming coverage of min_cov - alt_confidence_mut_rate_with_coverage = get_alt_mut_rate( - num_exclusive_kmers_coverage, - acceptance_threshold_with_coverage, - ksize, - significance=significance, - ) - - # How many unique k-mers do I actually see? - num_matches = exclusive_hashes_info_org[1] - #print(num_matches) #printing this for testing - # calculate the p-value considering the coverage - if num_matches <= num_exclusive_kmers_coverage: - p_val = binom.cdf(num_matches, num_exclusive_kmers_coverage, non_mut_p) - else: - p_val = 1.0 - # is the genome present? Takes coverage into account - in_sample_est = (num_matches >= acceptance_threshold_with_coverage) and ( - num_matches != 0 - ) - # return in_sample_est, p_val, num_exclusive_kmers, num_exclusive_kmers_coverage, num_matches, \ - # acceptance_threshold_wo_coverage, acceptance_threshold_with_coverage, actual_confidence_wo_coverage, \ - # actual_confidence_with_coverage, alt_confidence_mut_rate, alt_confidence_mut_rate_with_coverage - return ( - in_sample_est, - p_val, - num_exclusive_kmers, - num_exclusive_kmers_coverage, - num_matches, - acceptance_threshold_with_coverage, - actual_confidence_with_coverage, - alt_confidence_mut_rate_with_coverage, - ) - - -def hypothesis_recovery( - manifest: pd.DataFrame, - sample_info_set: Tuple[str, sourmash.SourmashSignature], - path_to_genome_temp_dir: str, - min_coverage_list: List[float], - scale: int, - ksize: int, - significance: float = 0.99, - ani_thresh: float = 0.95, - num_threads: int = 16, -): - """ - Go through each of the organisms that have non-zero overlap with the sample and perform a hypothesis test for the - presence of that organism in the sample: have we seen enough k-mers exclusive to that organism to conclude that - an organism with ANI > ani_thresh (to the one under consideration) is present 
in the sample? - :param manifest: a dataframe with the following columns: - 'organism_name', - 'md5sum', - 'num_unique_kmers_in_genome_sketch', - 'num_total_kmers_in_genome_sketch', - 'genome_scale_factor', - 'num_exclusive_kmers_in_sample_sketch', - 'num_total_kmers_in_sample_sketch', - 'sample_scale_factor', - 'min_coverage' - :param sample_info_set: a set of information about the sample, including the sample signature location and the sample signature object - :param path_to_genome_temp_dir: path to the genome temporary directory generated by the training step - :param min_coverage_list: a list of minimum coverage values - :param scale: scale factor - :param ksize: k-mer size - :param significance: significance level for the hypothesis test - :param ani_thresh: threshold for ANI (i.e. how similar do the genomes need to be in order to be considered the same) - :param num_threads: number of threads to use for parallelization - :return: a list of pandas dataframe with the results of the hypothesis tests based on different min_coverage values - """ - - # unpack the sample info set - sample_file, sample_sig = sample_info_set - - # create a temporary directory for the sample - sample_dir = os.path.dirname(sample_file) - sample_name = os.path.basename(sample_file).replace(".sig.zip", "") - path_to_sample_temp_dir = os.path.join( - sample_dir, f"sample_{sample_name}_intermediate_files" - ) - if os.path.exists(path_to_sample_temp_dir): - # if exists, remove it - logger.info(f"Removing existing temporary directory: {path_to_sample_temp_dir}") - os.system(f"rm -rf {path_to_sample_temp_dir}") - os.makedirs(path_to_sample_temp_dir) - - # Find the organisms that have non-zero overlap with the sample - nontrivial_organism_names = get_organisms_with_nonzero_overlap( - manifest, - sample_file, - scale, - ksize, - num_threads, - path_to_genome_temp_dir, - path_to_sample_temp_dir, - ) - - # Get the unique hashes exclusive to each of the organisms that have non-zero overlap with 
the sample - exclusive_hashes_info, manifest, final_stats_df = get_exclusive_hashes( - manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir - ) - - # Set up the results dataframe columns - # n.b. that the output of cov_calc is not being incorporated here; instead it's being returned separately, as a pandas dataframe. - given_columns = [ - "in_sample_est", # Main output: Boolean indicating whether genome is present in sample - "p_vals", # Probability of observing this or more extreme result at ANI threshold. - "num_exclusive_kmers_to_genome", # Number of k-mers exclusive to genome - "num_exclusive_kmers_to_genome_coverage", # Number of k-mers exclusive to genome, assuming coverage of min_cov - "num_matches", # Number of k-mers exclusive to genome that are present in the sample - # 'acceptance_threshold_wo_coverage', # Acceptance threshold without adjusting for coverage - # (how many k-mers need to be present in order to reject the null hypothesis) - "acceptance_threshold_with_coverage", # Acceptance threshold with adjusting for coverage - # 'actual_confidence_wo_coverage', # Actual confidence without adjusting for coverage - "actual_confidence_with_coverage", # Actual confidence with adjusting for coverage - # 'alt_confidence_mut_rate', # Mutation rate such that at this mutation rate, false positive rate = p_val. - # Does not account for min_coverage parameter. 
- "alt_confidence_mut_rate_with_coverage", # same as above, but accounting for min_coverage parameter - ] - - # Using multiprocessing.Pool to parallelize the execution - manifest_list = [] - for min_coverage in tqdm(min_coverage_list, desc="Computing hypothesis recovery"): - logger.info(f"Computing hypothesis recovery for min_coverage={min_coverage}") - with Pool(processes=num_threads) as p: - params = ( - ( - exclusive_hashes_info[i], - ksize, - significance, - ani_thresh, - min_coverage, - ) - for i in range(len(exclusive_hashes_info)) - ) - results = p.starmap(single_hyp_test, params) - print(f"Finished computing all results for min_coverage value: {min_coverage}") #for testing - - # Create a pandas dataframe to store the results - results = pd.DataFrame(results, columns=given_columns) - #print(results) #for testing - - # combine the results with the manifest - manifest["min_coverage"] = min_coverage - manifest_list.append(pd.concat([manifest, results], axis=1)) - - return manifest_list, final_stats_df diff --git a/cov_calc.py b/src/yacht/cov_calc.py similarity index 100% rename from cov_calc.py rename to src/yacht/cov_calc.py diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index 3420d36f..750bb403 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -14,11 +14,12 @@ from .utils import load_signature_with_ksize, decompress_all_sig_files # Configure Loguru logger from loguru import logger +from cov_calc import cov_calc warnings.filterwarnings("ignore") - +\ logger.remove() logger.add( sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}", level="INFO" @@ -143,6 +144,8 @@ def get_exclusive_hashes( 2. 
the number of unique hashes exclusive to the organism under consideration that are in the sample a new manifest dataframe that only contains the organisms that have non-zero overlap with the sample """ + pvalue_cutoff=0.9999999999 + min_count_thresh=3 #TODO: consider whether to change this value def __find_exclusive_hashes( md5sum: str, @@ -155,7 +158,7 @@ def __find_exclusive_hashes( os.path.join(path_to_temp_dir, "signatures", md5sum + SIG_SUFFIX), ksize ) return {hash for hash in sig.minhash.hashes if hash in single_occurrence_hashes} - + # get manifest information for the organisms that have non-zero overlap with the sample sub_manifest = manifest.loc[ manifest["organism_name"].isin(nontrivial_organism_names), : @@ -177,6 +180,10 @@ def __find_exclusive_hashes( multiple_occurrence_hashes.add(hash) else: single_occurrence_hashes.add(hash) + + + #print(multiple_occurrence_hashes) + del multiple_occurrence_hashes # free up memory # Find hashes that are unique to each organism @@ -188,10 +195,32 @@ def __find_exclusive_hashes( md5sum, path_to_genome_temp_dir, ksize, single_occurrence_hashes ) ) + + #print(f"Single occurrence hashes") + #print(single_occurrence_hashes) #adding this for testing del single_occurrence_hashes # free up memory # Get sample hashes sample_hashes = set(sample_sig.minhash.hashes) + + # Get sample hashes keys + sample_hashes_keys = sample_sig.minhash.hashes.keys() + #print(sample_hashes_keys) + samp_kmers_items = sample_sig.minhash.hashes.items() + samp_dict = dict(samp_kmers_items) + + stats_list = [] + for md5sum in tqdm(organism_md5sum_list, desc="Processing coverage per organism"): + sig = load_signature_with_ksize( + os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX), + ksize, + ) + stats_out = cov_calc(sample_sig, sig) #location of cov_calc, which calculates effective coverage and other things according to Shaw and Yu (2024) + stats_list.append(stats_out) + + final_stats_df = pd.concat(stats_list, 
ignore_index=True) + + del stats_list # free up memory # Find hashes that are unique to each organism and in the sample logger.info("Finding hashes that are unique to each organism and in the sample") @@ -199,12 +228,41 @@ def __find_exclusive_hashes( for i, exclusive_hashes in enumerate( tqdm(exclusive_hashes_org, desc="Matching exclusive hashes with sample") ): + #print(f"exclusive_hash", exclusive_hashes) exclusive_hashes_info.append( (len(exclusive_hashes), len(exclusive_hashes.intersection(sample_hashes))) ) - return exclusive_hashes_info, sub_manifest + # Calculate lambda and other related coverage metrics for each organism in the sample + #logger.info("Calculate lambda for each organism in the sample") + #for i, lambda_stats in enumerate() + + #print(type(exclusive_hashes_info)) + #print(exclusive_hashes_info) + #print(type(sub_manifest)) + + columns_of_interest = [ + 'naive_ani', + 'final_est_ani', + 'final_est_cov', + 'mean_cov', + 'median_cov', + 'lambda_ci', + 'ani_ci' + ] + + # Select only those columns from the DataFrame + selected_data = final_stats_df[columns_of_interest] + + summary_stats = selected_data.describe() + + #print(sub_manifest) + print(final_stats_df) + print(summary_stats) + #print(final_stats_df['lambda_ci'].unique()) + #print(final_stats_df['ani_ci'].unique()) + return exclusive_hashes_info, sub_manifest, final_stats_df def get_alt_mut_rate( nu: int, thresh: int, ksize: int, significance: float = 0.99 @@ -250,6 +308,7 @@ def single_hyp_test( """ # get the number of unique k-mers num_exclusive_kmers = exclusive_hashes_info_org[0] + #print(exclusive_hashes_info_org) ##printing the output of this to determine what the data structure looks like # mutation rate non_mut_p = (ani_thresh) ** ksize # # assuming coverage of 1, how many unique k-mers would I need to observe in order to reject the null hypothesis? @@ -282,6 +341,7 @@ def single_hyp_test( # How many unique k-mers do I actually see? 
num_matches = exclusive_hashes_info_org[1] + #print(num_matches) #printing this for testing # calculate the p-value considering the coverage if num_matches <= num_exclusive_kmers_coverage: p_val = binom.cdf(num_matches, num_exclusive_kmers_coverage, non_mut_p) @@ -369,11 +429,12 @@ def hypothesis_recovery( ) # Get the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample - exclusive_hashes_info, manifest = get_exclusive_hashes( + exclusive_hashes_info, manifest, final_stats_df = get_exclusive_hashes( manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir ) # Set up the results dataframe columns + # n.b. that the output of cov_calc is not being incorporated here; instead it's being returned separately, as a pandas dataframe. given_columns = [ "in_sample_est", # Main output: Boolean indicating whether genome is present in sample "p_vals", # Probability of observing this or more extreme result at ANI threshold. @@ -406,12 +467,14 @@ def hypothesis_recovery( for i in range(len(exclusive_hashes_info)) ) results = p.starmap(single_hyp_test, params) + print(f"Finished computing all results for min_coverage value: {min_coverage}") #for testing # Create a pandas dataframe to store the results results = pd.DataFrame(results, columns=given_columns) + #print(results) #for testing # combine the results with the manifest manifest["min_coverage"] = min_coverage manifest_list.append(pd.concat([manifest, results], axis=1)) - return manifest_list + return manifest_list, final_stats_df diff --git a/src/yacht/utils.py b/src/yacht/utils.py old mode 100644 new mode 100755 index beaf0988..38afcc54 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -1,14 +1,21 @@ import os import sys import numpy as np -import sourmash +from typing import List, Optional +from collections import Counter +from sourmash import load_one_signature +from collections import defaultdict from tqdm import tqdm import pandas as pd from multiprocessing 
import Pool from loguru import logger -from typing import Optional, List, Set, Dict, Tuple +from typing import Optional, List, Set, Dict, Tuple, Any import shutil import gzip +import math +import random +import sourmash +from dataclasses import dataclass from glob import glob # Configure Loguru logger @@ -17,16 +24,65 @@ sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}", level="INFO" ) -# Set up contants +# Set up constants COL_NOT_FOUND_ERROR = "Column not found: {}" FILE_LOCATION = os.path.dirname(os.path.realpath(__file__)) +# Adding two more contstants (RTR) +SAMPLE_SIZE_CUTOFF: int = 25 +PVALUE_CUTOFF: float = 0.9999999999 +ksize = 31 #Note: hard-coding this for now # Set up global variables -__version__ = "1.3.2" +__version__ = "2.0.1" GITHUB_API_URL = "https://api.github.com/repos/KoslickiLab/YACHT/contents/demo/{path}" GITHUB_RAW_URL = "https://raw.githubusercontent.com/KoslickiLab/YACHT/main/demo/{path}" BASE_URL = "https://farm.cse.ucdavis.edu/~ctbrown/sourmash-db/" -ZENODO_COMMUNITY_URL = "https://zenodo.org/api/records?q=communities:yacht&size=25" +ZENODO_COMMUNITY_URL = "https://zenodo.org/api/records/?communities=yacht&size=100" + +# A dataclass to implement something equivalent to sylph's rust-based enum implementation (AdjustStatus) +@dataclass(frozen=True) +class AdjustStatusLambda: + value: float + +class AdjustStatusHigh: + pass + +class AdjustStatusLow: + pass + +class AdjustStatusNone: + pass + +# Class for cov_calc output +@dataclass +class AniResult: + naive_ani: float + final_est_ani: float + final_est_cov: float + seq_name: str + gn_name: str + contig_name: str + mean_cov: float + median_cov: float + containment_index: Tuple[int, int] + lambda_status: AdjustStatusLambda + ani_ci: Tuple[Optional[float], Optional[float]] + lambda_ci: Tuple[Optional[float], Optional[float]] + genome_sketch: Any + rel_abund: Optional[float] + seq_abund: Optional[float] + kmers_lost: Optional[int] + +# Class for cov_calc arguments +# Note 
that these are being set here for now, but could be brought as CLI arguments for yacht +class _ContainArgs: + def __init__(self): + self.ci_int = True + self.ratio = True + self.mme = True + self.nb = True + self.mle = True + self.min_count_correct = 1 # Example value def load_signature_with_ksize(filename: str, ksize: int) -> sourmash.SourmashSignature: """ @@ -507,3 +563,298 @@ def decompress_all_sig_files(sig_files: List[str], num_threads: int) -> None: p.map(_decompress_and_remove, sig_files) logger.info("All .sig.gz files have been decompressed.") + +def load_one_sig(sig_path: str, ksize: int): + """ + Imports a sourmash signature file using the path and the ksize. + """ + loaded_sig = load_one_signature(sig_path, ksize, select_moltype='DNA').to_mutable() + + logger.info("The sig file has been loaded." + ) + return(loaded_sig) + +def newton_raphson(ratio: float, mean: float): + """ + Shaw and Yu (2024)'s implmentation of Newton-Raphson use to assist in the calculation of lambda. + """ + curr = mean / (1 - ratio) + #print(1-mean) + #print(1-ratio) + for _ in range(1000): #iterates to converge on an approximation for the root + t1 = (1 - ratio) * curr + e_curr = math.exp(-curr) + t2 = mean * (1 - e_curr) + t3 = 1 - ratio + t4 = mean * e_curr + curr = curr - (t1 - t2) / (t3 - t4) + return curr + +def mle_zip(full_covs: list[int], _k: float): + """ + Maximum likelihood estimator for the zero-inflated Poisson (ZIP) distribution from Shaw and Yu (2024) + """ + n_zero = 0 + count_set = Counter() #creating a new counter set + + for x in full_covs: + if x == 0: + n_zero += 1 + else: + count_set[x] +=1 + + if len(count_set) == 1: #If no info for inference, retuns None + return None + + if len(full_covs) - n_zero < SAMPLE_SIZE_CUTOFF: + return None + + mean = np.mean(full_covs) + nr_input = num_zero/len(full_covs) + lambda_out = newton_raphson(nr_input, mean) + + if lambda_out < 0 or math.isnan(lambda_out): + lambda_ret = None + else: + lambda_ret = lambda_out + return 
lambda_ret + +def variance(data: str(int)): + """ + An internal function that calculates the variance, which is the average of the squared differences of all values from the mean + """ + if len(data) == 0: + return None + mean = np.mean(data) + var = 0 + for x in data: + var += (float(x) - mean) * (float(x) - mean) + var_out = float(var / len(data)) + return var_out + +def ratio_lambda(full_covs: list[int], min_count_correct): + """ + Estimating lambda according to Shaw and Yu (2024) using the ratio method + """ + n_zero = 0 + count_map = defaultdict(int) # Creates an empty dictionary for integer values equivalent to FxHashMap::default() + for x in full_covs: + if x == 0: + n_zero += 1 + else: + count_map[x] +=1 + if len(count_map) == 1: # Absent info for inference, returns None. + return None + + if (len(full_covs) - n_zero) < SAMPLE_SIZE_CUTOFF: + return None + else: + sort_vec = [(value, key) for key, value in count_map.items()] + sort_vec.sort(reverse=True) #sorting the vector in reverse order + most_ind = sort_vec[0][1] + if (most_ind + 1) not in count_map: + return None + + count_p1 = count_map[(most_ind + 1)] + count = count_map[most_ind] + + if count_p1 < min_count_correct or count < min_count_correct: + return None + + lambda_out = count_p1 / (count * (most_ind + 1)) + return lambda_out + +def r_moments_lambda(m: float, v: float, lambda_out: float): + """ + Internal function used used in the calculation of lambda using ratio from moments + """ + result = lambda_out / (v - 1 + lambda_out + m) + return result + +def ratio_calc(val: float, r: float, lambda_out: float): + """ + Internal function used used in the calculation of lambda using ratio from moments + """ + if (r < 100): + return gamma(r + val + 1) / (val + 1) / gamma(r + val) * lambda_out / (r + lambda_out) + else: + return (r + val + 1) / (val + 1) * lambda_out / (r + lambda_out) + +def ratio_from_moments_lambda(val: float, lambda_out: float, m: float, v: float): + """ + Function that calculates 
lambda using the ratio from moments formula + """ + r = r_moments_lambda(m, v, lambda_out) + if r < 0: + return None + rat_return = ratio_calc(val, r, lambda_out) + return rat_return + +def mme_lambda(full_covs: list[int]) -> Optional[float]: + """ + Calculates the "method of moments" estimator for the lambda parameter + from a list of coverage values. + """ + num_zero = 0 + count_set = set() + + for x in full_covs: + if x == 0: + num_zero += 1 + else: + count_set.add(x) + if len(count_set) == 1: + return None + # Does the number of non-zero observations meet the cutoff threshold? + if len(full_covs) - num_zero < SAMPLE_SIZE_CUTOFF: + return None + # Calculating mean and variance + try: + mean_val = np.mean(full_covs) + variance_val = variance(full_covs) + except ValueError: + return None + # Calculating lambda using the MME formula + lambda_val = variance_val / mean_val + mean_val - 1.0 + + # Ensuring lambda is non-negative + if lambda_val < 0.0: + return None + else: + return lambda_val + +def binary_search_lambda(full_covs: list[int]): + if len(full_covs) == 0: + return None + m = mean(full_covs) + v = var(full_covs) + nonzero = 0 + ones = 0 + twos = 0 + + for x in full_covs: + if x != 0: + _nonzero += 1 + if x == 1: + ones += 1 + elif x == 2: + twos += 1 + + ratio_est = float(twos) / float(ones) + + left = float(max(0.003, m - 2)) + right = m + 5 + endpoints = ("start", "end") + left, right = endpoints + best = None + best_val = 10000 + for i in range(10000): + test = (endpoints - endpoints)/10000 * float(i) + endpoints + proposed = ratio_from_moments_lambda(1, test, m, v); + if proposed is not None: + p = proposed - ratio_est + if abs(p) < best_val: + best_val = abs(p) + best = test + + if best == None: + return None + #consider putting in a RaiseError statement here + r = r_moments_lambda(m, v, best) + # removed debugging code from sylph (see inference.rs for reference) + return best + +def bootstrap_interval(covs_full: list[int], k: float, args: 
_ContainArgs): + """ + This function calculates bootstrap confidence intervals for ANI and lambda. + """ + if args.ci_int == False: + return (None, None, None, None) + + #logger.info("Bootstrap interval") + #print(f"Bootstrap interval") #for testing #TODO 12/3: look into whether/where this function is being activated + num_samp = len(covs_full) + iters = 100 + res_ani = [] + res_lambda = [] + + for _ in range(iters): + rand_vec = [] + for _ in range(num_samp): + rand_vec.append(random.choice(covs_full)) + if args.ratio: + lambda_val = ratio_lambda(rand_vec, args.min_count_correct) + elif args.mme: + lambda_val = mme_lambda(rand_vec) + elif args.nb: + lambda_val = binary_search_lambda(rand_vec) + elif args.mle: + lambda_val = mle_zip(rand_vec, ksize) + else: + lambda_val = ratio_lambda(rand_vec, args.min_count_correct) + + #print(f"lambda_val is:") #for testing + #print(lambda_val) + + ani_val = ani_from_lambda(lambda_val, np.mean(rand_vec), ksize, rand_vec) + + if ani_val is not None and lambda_val is not None: + if not np.isnan(ani_val) and not np.isnan(lambda_val): + res_ani.append(ani_val) + res_lambda.append(lambda_val) + + res_ani.sort() + res_lambda.sort() + + if len(res_ani) < 50: + return (None, None, None, None) + + suc = len(res_ani) + low_ani = res_ani[suc * 5 // 100] + high_ani = res_ani[suc * 95 // 100] + low_lambda = res_lambda[suc * 5 // 100] + high_lambda = res_lambda[suc * 95 // 100] + + #print(f"Bootstrap interval") #for testing + #print(low_ani, high_ani, low_lambda, high_lambda) + return (low_ani, high_ani, low_lambda, high_lambda) + +def ani_from_lambda(lambda_val, lam_mean, k_value, full_cov): + """ + Calculates an adjusted ani value + Args: + lambda_val: An optional float used to calculate ani + lam_mean: A float value (unused in the original logic). + k: A float value used as the inverse exponent. + full_cov: A list of integers to analyze for non-zero counts. 
+ + Returns: + An optional float representing the calculated adjusted index 'ani', + or None if the input lambda is None, or if 'ani' is negative or NaN. + """ + if lambda_val == None: + return None + contain_count = 0 + zero_count = 0 + for x in full_cov: + if x != 0: + contain_count += 1 + + else: + zero_count += 1 + + if not full_cov: + return None + + adj_index = contain_count / (1.0 - math.exp(-lambda_val)) / len(full_cov) + + # Calculating ani using math.pow for clarity (and to maintain type correctness) + ani = math.pow(adj_index, 1.0 / k_value) + + if ani < 0.0 or math.isnan(ani): + ret_ani = None + + else: + ret_ani = ani + + return ret_ani \ No newline at end of file diff --git a/utils.py b/utils.py deleted file mode 100755 index 38afcc54..00000000 --- a/utils.py +++ /dev/null @@ -1,860 +0,0 @@ -import os -import sys -import numpy as np -from typing import List, Optional -from collections import Counter -from sourmash import load_one_signature -from collections import defaultdict -from tqdm import tqdm -import pandas as pd -from multiprocessing import Pool -from loguru import logger -from typing import Optional, List, Set, Dict, Tuple, Any -import shutil -import gzip -import math -import random -import sourmash -from dataclasses import dataclass -from glob import glob - -# Configure Loguru logger -logger.remove() -logger.add( - sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}", level="INFO" -) - -# Set up constants -COL_NOT_FOUND_ERROR = "Column not found: {}" -FILE_LOCATION = os.path.dirname(os.path.realpath(__file__)) -# Adding two more contstants (RTR) -SAMPLE_SIZE_CUTOFF: int = 25 -PVALUE_CUTOFF: float = 0.9999999999 -ksize = 31 #Note: hard-coding this for now - -# Set up global variables -__version__ = "2.0.1" -GITHUB_API_URL = "https://api.github.com/repos/KoslickiLab/YACHT/contents/demo/{path}" -GITHUB_RAW_URL = "https://raw.githubusercontent.com/KoslickiLab/YACHT/main/demo/{path}" -BASE_URL = 
"https://farm.cse.ucdavis.edu/~ctbrown/sourmash-db/" -ZENODO_COMMUNITY_URL = "https://zenodo.org/api/records/?communities=yacht&size=100" - -# A dataclass to implement something equivalent to sylph's rust-based enum implementation (AdjustStatus) -@dataclass(frozen=True) -class AdjustStatusLambda: - value: float - -class AdjustStatusHigh: - pass - -class AdjustStatusLow: - pass - -class AdjustStatusNone: - pass - -# Class for cov_calc output -@dataclass -class AniResult: - naive_ani: float - final_est_ani: float - final_est_cov: float - seq_name: str - gn_name: str - contig_name: str - mean_cov: float - median_cov: float - containment_index: Tuple[int, int] - lambda_status: AdjustStatusLambda - ani_ci: Tuple[Optional[float], Optional[float]] - lambda_ci: Tuple[Optional[float], Optional[float]] - genome_sketch: Any - rel_abund: Optional[float] - seq_abund: Optional[float] - kmers_lost: Optional[int] - -# Class for cov_calc arguments -# Note that these are being set here for now, but could be brought as CLI arguments for yacht -class _ContainArgs: - def __init__(self): - self.ci_int = True - self.ratio = True - self.mme = True - self.nb = True - self.mle = True - self.min_count_correct = 1 # Example value - -def load_signature_with_ksize(filename: str, ksize: int) -> sourmash.SourmashSignature: - """ - Helper function that loads the signature for a given kmer size from the provided signature file. - Filename should point to a sourmash signature file. Raises exception if given kmer size is not present in the file. - This is a wrapper of sourmash.load_file_as_signatures, and accept all types of format: .sig, .sig.zip, .sbt, .lca, and .sqldb. - However, this function specifically ask for 1 signature so lca format is not appropriate; as of sourmash v4.8, sqldb doesn't accept "abund" parameter for signatures. 
- :param filename: string (location of the signature file of any format: .sig, .sig.zip, .sbt, .lca, and .sqldb) - :param ksize: int (size of kmer) - :return: sourmash signature - """ - # Take the first sample signature with the given kmer size - sketches = list(sourmash.load_file_as_signatures(filename, ksize=ksize)) - if len(sketches) != 1: - raise ValueError( - f"Expected exactly one signature with ksize {ksize} in {filename}, found {len(sketches)}" - ) - if len(sketches[0].minhash.hashes) == 0: - raise ValueError( - "Empty sketch in signature. This may be due to too high of a scale factor, please reduce it, eg. --scaled=1, and try again." - ) - return sketches[0] - - -def get_num_kmers( - minhash_mean_abundance: Optional[float], - minhash_hashes_len: int, - minhash_scaled: int, - scale: bool = True, -) -> int: - """ - Helper function that estimates the total number of kmers in a given sample. - :param minhash_mean_abundance: float or None (mean abundance of the signature) - :param minhash_hashes_len: int (number of hashes in the signature) - :param minhash_scaled: int (scale factor of the signature) - :param scale: bool (whether to scale the number of kmers by the scale factor) - :return: int (estimated total number of kmers) - """ - # Abundances may not have been kept, in which case, just use 1 - if minhash_mean_abundance: - num_kmers = minhash_mean_abundance * minhash_hashes_len - else: - num_kmers = minhash_hashes_len - if scale: - num_kmers *= minhash_scaled - return int(np.round(num_kmers)) - - -def check_file_existence(file_path: str, error_description: str) -> None: - """ - Helper function that checks if a file exists. If not, raises a ValueError with the given error description. 
- :param file_path: string (location of the file) - :param error_description: string (description of the error) - :return: None - """ - if not os.path.exists(file_path): - raise ValueError(error_description) - - -def get_info_from_single_sig( - sig_file: str, ksize: int -) -> Tuple[str, str, float, int, int]: - """ - Helper function that gets signature information (raw file path, name, md5sum, minhash mean abundance, minhash_hashes_len, minhash scaled) from a single sourmash signature file. - :param sig_file: string (location of the signature file with .sig.gz format) - :param ksize: int (size of kmer) - :return: tuple (name, md5sum, minhash mean abundance, minhash_hashes_len, minhash scaled) - """ - try: - sig = load_signature_with_ksize(sig_file, ksize) - return ( - sig_file, - sig.name, - sig.md5sum(), - sig.minhash.mean_abundance, - len(sig.minhash.hashes), - sig.minhash.scaled, - ) - except: - logger.warning(f"CANNOT extract the relevant info from the signature file: {sig_file}") - return None - -def run_yacht_train_core( - num_threads: int, ani_thresh: float, ksize: int, path_to_temp_dir: str, sig_info_dict: Dict[str, Tuple[str, float, int, int, str]], num_genome_threshold: int = 1000000 -) -> Dict[str, List[str]]: - """ - Helper function that runs the cpp script developed by Mahmudur Rahman Hera to find the closely related genomes with ANI > ani_thresh from the reference database, - then remove them, and generate a dataframe with the selected genomes. 
- :param num_threads: int (number of threads to use) - :param ani_thresh: float (threshold for ANI, below which we consider two organisms to be distinct) - :param ksize: int (size of kmer) - :param path_to_temp_dir: string (path to the folder to store the intermediate files) - :return: a dataframe containing the selected reference signature information - """ - - # run Mahmudur's cpp for genome comparison - # save signature files to a text file - sig_files = pd.DataFrame( - [ - os.path.join(path_to_temp_dir, "signatures", file) - for file in os.listdir(os.path.join(path_to_temp_dir, "signatures")) - ] - ) - sig_files_path = os.path.join(path_to_temp_dir, "training_sig_files.tsv") - sig_files.to_csv(sig_files_path, header=False, index=False) - - # convert ani threshold to containment threshold - containment_thresh = ani_thresh**ksize - total_sig_files = len(sig_files) - if total_sig_files <= num_genome_threshold: - passes = 1 - else: - passes = int(total_sig_files / num_genome_threshold) + 1 - cmd = f"{FILE_LOCATION}/run_yacht_train_core -t {num_threads} -c {containment_thresh} -p {passes} {sig_files_path} {path_to_temp_dir} {os.path.join(path_to_temp_dir, 'selected_result.tsv')}" - logger.info(f"Running comparison algorithm with command: {cmd}") - exit_code = os.system(cmd) - if exit_code != 0: - raise ValueError(f"Error running comparison algorithm with command: {cmd}") - - # move all split comparison files to a single foldr - os.makedirs(os.path.join(path_to_temp_dir, "comparison_files"), exist_ok=True) - for file in glob(os.path.join(path_to_temp_dir, "*.txt")): - shutil.move(file, os.path.join(path_to_temp_dir, "comparison_files")) - - # get info from the signature files of selected genomes - selected_sig_files = pd.read_csv(os.path.join(path_to_temp_dir, 'selected_result.tsv'), sep="\t", header=None) - selected_sig_files = selected_sig_files[0].to_list() - - # get the mapping from signature file name to genome name - mapping = {sig_info_dict[name][-1]:name for 
name in sig_info_dict} - selected_genome_names_set = set([mapping[sig_file_path] for sig_file_path in selected_sig_files]) - - # remove the close related organisms from the reference genome list - manifest_df = [] - for sig_name, ( - md5sum, - minhash_mean_abundance, - minhash_hashes_len, - minhash_scaled, - _ - ) in tqdm(sig_info_dict.items(), desc="Removing close related organisms from the reference genome list"): - if sig_name in selected_genome_names_set: - manifest_df.append( - ( - sig_name, - md5sum, - minhash_hashes_len, - get_num_kmers( - minhash_mean_abundance, - minhash_hashes_len, - minhash_scaled, - False, - ), - minhash_scaled, - ) - ) - manifest_df = pd.DataFrame( - manifest_df, - columns=[ - "organism_name", - "md5sum", - "num_unique_kmers_in_genome_sketch", - "num_total_kmers_in_genome_sketch", - "genome_scale_factor", - ], - ) - - return manifest_df - - - -def collect_signature_info( - num_threads: int, ksize: int, path_to_temp_dir: str -) -> Dict[str, Tuple[str, float, int, int]]: - """ - Helper function that collects signature information (raw file path, name, md5sum, minhash mean abundance, minhash_hashes_len, minhash scaled) from a sourmash signature database. 
- :param num_threads: int (number of threads to use) - :param ksize: int (size of kmer) - :param path_to_temp_dir: string (path to the folder to store the intermediate files) - :return: a dictionary mapping signature name to a tuple (md5sum, minhash mean abundance, minhash_hashes_len, minhash scaled, raw file path) - """ - ## extract in parallel - with Pool(num_threads) as p: - signatures = p.starmap( - get_info_from_single_sig, - [ - (os.path.join(path_to_temp_dir, "signatures", file), ksize) - for file in os.listdir(os.path.join(path_to_temp_dir, "signatures")) - ], - ) - - return {sig[1]: (sig[2], sig[3], sig[4], sig[5], sig[0]) for sig in tqdm(signatures) if sig} - - -class Prediction: - def __init__(self): - self._rank = None - self._taxid = None - self._percentage = None - self._taxpath = None - self._taxpathsn = None - - @property - def rank(self): - return self._rank - - @rank.setter - def rank(self, value): - self._rank = value - - @property - def taxid(self): - return self._taxid - - @taxid.setter - def taxid(self, value): - self._taxid = value - - @property - def percentage(self): - return self._percentage - - @percentage.setter - def percentage(self, value): - self._percentage = value - - @property - def taxpath(self): - return self._taxpath - - @taxpath.setter - def taxpath(self, value): - self._taxpath = value - - @property - def taxpathsn(self): - return self._taxpathsn - - @taxpathsn.setter - def taxpathsn(self, value): - self._taxpathsn = value - - def get_dict(self): - return self.__dict__ - - def get_pretty_dict(self): - return { - property.split("_")[1]: value - for property, value in self.__dict__.items() - if property.startswith("_") - } - - def get_metadata(self): - return { - "rank": self._rank, - "taxpath": self._taxpath, - "taxpathsn": self._taxpathsn, - } - - -def get_column_indices( - column_name_to_index: Dict[str, int] -) -> Tuple[int, int, int, int, Optional[int]]: - """ - (thanks to https://github.com/CAMI-challenge/OPAL, this 
function is modified get_column_indices from its load_data.py) - Helper function that gets the column indices for the following columns: TAXID, RANK, PERCENTAGE, TAXPATH, TAXPATHSN - :param column_name_to_index: dictionary mapping column name to column index - :return: indices for TAXID, RANK, PERCENTAGE, TAXPATH, TAXPATHSN - """ - # Assuming all other indices are mandatory and only index_taxpathsn can be optional - index_taxpathsn: Optional[int] = None # Correctly annotated to allow None - - if "TAXID" not in column_name_to_index: - logger.error(COL_NOT_FOUND_ERROR.format("TAXID")) - raise RuntimeError - if "RANK" not in column_name_to_index: - logger.error(COL_NOT_FOUND_ERROR.format("RANK")) - raise RuntimeError - if "PERCENTAGE" not in column_name_to_index: - logger.error(COL_NOT_FOUND_ERROR.format("PERCENTAGE")) - raise RuntimeError - if "TAXPATH" not in column_name_to_index: - logger.error(COL_NOT_FOUND_ERROR.format("TAXPATH")) - raise RuntimeError - index_taxid = column_name_to_index["TAXID"] - index_rank = column_name_to_index["RANK"] - index_percentage = column_name_to_index["PERCENTAGE"] - index_taxpath = column_name_to_index["TAXPATH"] - if "TAXPATHSN" in column_name_to_index: - index_taxpathsn = column_name_to_index["TAXPATHSN"] - else: - index_taxpathsn = None - return index_rank, index_taxid, index_percentage, index_taxpath, index_taxpathsn - - -def get_cami_profile( - cami_content: List[str], -) -> List[Tuple[str, Dict[str, str], List[Prediction]]]: - header: Dict[str, str] = {} # Dictionary mapping strings to strings - profile: List[Prediction] = [] # List of Prediction objects - predictions_dict: Dict[ - str, Prediction - ] = {} # Mapping from string to Prediction object - """ - (thanks to https://github.com/CAMI-challenge/OPAL, this function is modified open_profile_from_tsv from its load_data.py) - Helper function that opens a CAMI profile file and returns sample profiling information. 
- params:cami_content: list of strings (lines of the CAMI profile file) - return: list of tuples (sample_id, header, profile) - """ - header = {} - column_name_to_index = {} - profile = [] - samples_list = [] - predictions_dict = {} - reading_data = False - got_column_indices = False - - for line in cami_content: - if len(line.strip()) == 0 or line.startswith("#"): - continue - line = line.rstrip("\n") - - # parse header with column indices - if line.startswith("@@"): - for index, column_name in enumerate(line[2:].split("\t")): - column_name_to_index[column_name] = index - ( - index_rank, - index_taxid, - index_percentage, - index_taxpath, - index_taxpathsn, - ) = get_column_indices(column_name_to_index) - got_column_indices = True - reading_data = False - continue - - # parse header with metadata - if line.startswith("@"): - # if last line contained sample data and new header starts, store profile for sample - if reading_data: - if "SAMPLEID" in header and "VERSION" in header and "RANKS" in header: - if len(profile) > 0: - samples_list.append((header["SAMPLEID"], header, profile)) - profile = [] - predictions_dict = {} - else: - logger.error( - "Header is incomplete. Check if the header of each sample contains at least SAMPLEID, VERSION, and RANKS." - ) - raise RuntimeError - header = {} - reading_data = False - got_column_indices = False - key, value = line[1:].split(":", 1) - header[key.upper()] = value.strip() - continue - - if not got_column_indices: - logger.error( - "Header line starting with @@ is missing or at wrong position." 
- ) - raise RuntimeError - - reading_data = True - row_data = line.split("\t") - - taxid = row_data[index_taxid] - # if there is already a prediction for taxon, only sum abundance - if taxid in predictions_dict: - prediction = predictions_dict[taxid] - prediction.percentage += float(row_data[index_percentage]) - else: - if int(float(row_data[index_percentage])) == 0: - continue - prediction = Prediction() - predictions_dict[taxid] = prediction - prediction.taxid = row_data[index_taxid] - prediction.rank = row_data[index_rank] - prediction.percentage = float(row_data[index_percentage]) - prediction.taxpath = row_data[index_taxpath] - if isinstance(index_taxpathsn, int): - prediction.taxpathsn = row_data[index_taxpathsn] - else: - prediction.taxpathsn = None - profile.append(prediction) - - # store profile for last sample - if "SAMPLEID" in header and "VERSION" in header and "RANKS" in header: - if reading_data and len(profile) > 0: - samples_list.append((header["SAMPLEID"], header, profile)) - else: - logger.error( - "Header is incomplete. Check if the header of each sample contains at least SAMPLEID, VERSION, and RANKS." - ) - raise RuntimeError - - return samples_list - - -def create_output_folder(outfolder): - """ - Helper function that creates the output folder if it does not exist. - :param outfolder: location of output folder - :return: None - """ - if not os.path.exists(outfolder): - logger.info(f"Creating output folder: {outfolder}") - os.makedirs(outfolder) - - -def check_download_args(args, db_type): - """ - Helper function that checks if the input arguments are valid. - :param args: input arguments - :param db_type: type of database options: "pretrained" or "default" - :return: None - """ - if args.database not in ["genbank", "gtdb"]: - logger.error( - f"Invalid database: {args.database}. Now only support genbank and gtdb." - ) - sys.exit(1) - - if args.k not in [21, 31, 51]: - logger.error(f"Invalid k: {args.k}. 
Now only support 21, 31, and 51.") - sys.exit(1) - - if args.database == "genbank": - if args.ncbi_organism is None: - logger.warning( - "No NCBI organism specified using parameter --ncbi_organism. Using the default: bacteria" - ) - args.ncbi_organism = "bacteria" - - if args.ncbi_organism not in [ - "archaea", - "bacteria", - "fungi", - "virus", - "protozoa", - ]: - logger.error( - f"Invalid NCBI organism: {args.ncbi_organism}. Now only support archaea, bacteria, fungi, virus, and protozoa." - ) - sys.exit(1) - - if db_type == "pretrained" and args.ncbi_organism == "virus": - logger.error("We now haven't supported for virus database.") - sys.exit(1) - - -def _decompress_and_remove(file_path: str) -> None: - """ - Decompresses a GZIP-compressed file and removes the original compressed file. - :param file_path: The path to the .sig.gz file that needs to be decompressed and deleted. - :return: None - """ - try: - output_filename = os.path.splitext(file_path)[0] - with gzip.open(file_path, 'rb') as f_in: - with open(output_filename, 'wb') as f_out: - f_out.write(f_in.read()) - - os.remove(file_path) - - except Exception as e: - logger.info(f"Failed to process {file_path}: {e}") - -def decompress_all_sig_files(sig_files: List[str], num_threads: int) -> None: - """ - Decompresses all .sig.gz files in the list using multiple threads. - :param sig_files: List of .sig.gz files that need to be decompressed. - :param num_threads: Number of threads to use for decompression. - :return: None - """ - with Pool(num_threads) as p: - p.map(_decompress_and_remove, sig_files) - - logger.info("All .sig.gz files have been decompressed.") - -def load_one_sig(sig_path: str, ksize: int): - """ - Imports a sourmash signature file using the path and the ksize. - """ - loaded_sig = load_one_signature(sig_path, ksize, select_moltype='DNA').to_mutable() - - logger.info("The sig file has been loaded." 
- ) - return(loaded_sig) - -def newton_raphson(ratio: float, mean: float): - """ - Shaw and Yu (2024)'s implmentation of Newton-Raphson use to assist in the calculation of lambda. - """ - curr = mean / (1 - ratio) - #print(1-mean) - #print(1-ratio) - for _ in range(1000): #iterates to converge on an approximation for the root - t1 = (1 - ratio) * curr - e_curr = math.exp(-curr) - t2 = mean * (1 - e_curr) - t3 = 1 - ratio - t4 = mean * e_curr - curr = curr - (t1 - t2) / (t3 - t4) - return curr - -def mle_zip(full_covs: list[int], _k: float): - """ - Maximum likelihood estimator for the zero-inflated Poisson (ZIP) distribution from Shaw and Yu (2024) - """ - n_zero = 0 - count_set = Counter() #creating a new counter set - - for x in full_covs: - if x == 0: - n_zero += 1 - else: - count_set[x] +=1 - - if len(count_set) == 1: #If no info for inference, retuns None - return None - - if len(full_covs) - n_zero < SAMPLE_SIZE_CUTOFF: - return None - - mean = np.mean(full_covs) - nr_input = num_zero/len(full_covs) - lambda_out = newton_raphson(nr_input, mean) - - if lambda_out < 0 or math.isnan(lambda_out): - lambda_ret = None - else: - lambda_ret = lambda_out - return lambda_ret - -def variance(data: str(int)): - """ - An internal function that calculates the variance, which is the average of the squared differences of all values from the mean - """ - if len(data) == 0: - return None - mean = np.mean(data) - var = 0 - for x in data: - var += (float(x) - mean) * (float(x) - mean) - var_out = float(var / len(data)) - return var_out - -def ratio_lambda(full_covs: list[int], min_count_correct): - """ - Estimating lambda according to Shaw and Yu (2024) using the ratio method - """ - n_zero = 0 - count_map = defaultdict(int) # Creates an empty dictionary for integer values equivalent to FxHashMap::default() - for x in full_covs: - if x == 0: - n_zero += 1 - else: - count_map[x] +=1 - if len(count_map) == 1: # Absent info for inference, returns None. 
- return None - - if (len(full_covs) - n_zero) < SAMPLE_SIZE_CUTOFF: - return None - else: - sort_vec = [(value, key) for key, value in count_map.items()] - sort_vec.sort(reverse=True) #sorting the vector in reverse order - most_ind = sort_vec[0][1] - if (most_ind + 1) not in count_map: - return None - - count_p1 = count_map[(most_ind + 1)] - count = count_map[most_ind] - - if count_p1 < min_count_correct or count < min_count_correct: - return None - - lambda_out = count_p1 / (count * (most_ind + 1)) - return lambda_out - -def r_moments_lambda(m: float, v: float, lambda_out: float): - """ - Internal function used used in the calculation of lambda using ratio from moments - """ - result = lambda_out / (v - 1 + lambda_out + m) - return result - -def ratio_calc(val: float, r: float, lambda_out: float): - """ - Internal function used used in the calculation of lambda using ratio from moments - """ - if (r < 100): - return gamma(r + val + 1) / (val + 1) / gamma(r + val) * lambda_out / (r + lambda_out) - else: - return (r + val + 1) / (val + 1) * lambda_out / (r + lambda_out) - -def ratio_from_moments_lambda(val: float, lambda_out: float, m: float, v: float): - """ - Function that calculates lambda using the ratio from moments formula - """ - r = r_moments_lambda(m, v, lambda_out) - if r < 0: - return None - rat_return = ratio_calc(val, r, lambda_out) - return rat_return - -def mme_lambda(full_covs: list[int]) -> Optional[float]: - """ - Calculates the "method of moments" estimator for the lambda parameter - from a list of coverage values. - """ - num_zero = 0 - count_set = set() - - for x in full_covs: - if x == 0: - num_zero += 1 - else: - count_set.add(x) - if len(count_set) == 1: - return None - # Does the number of non-zero observations meet the cutoff threshold? 
- if len(full_covs) - num_zero < SAMPLE_SIZE_CUTOFF: - return None - # Calculating mean and variance - try: - mean_val = np.mean(full_covs) - variance_val = variance(full_covs) - except ValueError: - return None - # Calculating lambda using the MME formula - lambda_val = variance_val / mean_val + mean_val - 1.0 - - # Ensuring lambda is non-negative - if lambda_val < 0.0: - return None - else: - return lambda_val - -def binary_search_lambda(full_covs: list[int]): - if len(full_covs) == 0: - return None - m = mean(full_covs) - v = var(full_covs) - nonzero = 0 - ones = 0 - twos = 0 - - for x in full_covs: - if x != 0: - _nonzero += 1 - if x == 1: - ones += 1 - elif x == 2: - twos += 1 - - ratio_est = float(twos) / float(ones) - - left = float(max(0.003, m - 2)) - right = m + 5 - endpoints = ("start", "end") - left, right = endpoints - best = None - best_val = 10000 - for i in range(10000): - test = (endpoints - endpoints)/10000 * float(i) + endpoints - proposed = ratio_from_moments_lambda(1, test, m, v); - if proposed is not None: - p = proposed - ratio_est - if abs(p) < best_val: - best_val = abs(p) - best = test - - if best == None: - return None - #consider putting in a RaiseError statement here - r = r_moments_lambda(m, v, best) - # removed debugging code from sylph (see inference.rs for reference) - return best - -def bootstrap_interval(covs_full: list[int], k: float, args: _ContainArgs): - """ - This function calculates bootstrap confidence intervals for ANI and lambda. 
- """ - if args.ci_int == False: - return (None, None, None, None) - - #logger.info("Bootstrap interval") - #print(f"Bootstrap interval") #for testing #TODO 12/3: look into whether/where this function is being activated - num_samp = len(covs_full) - iters = 100 - res_ani = [] - res_lambda = [] - - for _ in range(iters): - rand_vec = [] - for _ in range(num_samp): - rand_vec.append(random.choice(covs_full)) - if args.ratio: - lambda_val = ratio_lambda(rand_vec, args.min_count_correct) - elif args.mme: - lambda_val = mme_lambda(rand_vec) - elif args.nb: - lambda_val = binary_search_lambda(rand_vec) - elif args.mle: - lambda_val = mle_zip(rand_vec, ksize) - else: - lambda_val = ratio_lambda(rand_vec, args.min_count_correct) - - #print(f"lambda_val is:") #for testing - #print(lambda_val) - - ani_val = ani_from_lambda(lambda_val, np.mean(rand_vec), ksize, rand_vec) - - if ani_val is not None and lambda_val is not None: - if not np.isnan(ani_val) and not np.isnan(lambda_val): - res_ani.append(ani_val) - res_lambda.append(lambda_val) - - res_ani.sort() - res_lambda.sort() - - if len(res_ani) < 50: - return (None, None, None, None) - - suc = len(res_ani) - low_ani = res_ani[suc * 5 // 100] - high_ani = res_ani[suc * 95 // 100] - low_lambda = res_lambda[suc * 5 // 100] - high_lambda = res_lambda[suc * 95 // 100] - - #print(f"Bootstrap interval") #for testing - #print(low_ani, high_ani, low_lambda, high_lambda) - return (low_ani, high_ani, low_lambda, high_lambda) - -def ani_from_lambda(lambda_val, lam_mean, k_value, full_cov): - """ - Calculates an adjusted ani value - Args: - lambda_val: An optional float used to calculate ani - lam_mean: A float value (unused in the original logic). - k: A float value used as the inverse exponent. - full_cov: A list of integers to analyze for non-zero counts. - - Returns: - An optional float representing the calculated adjusted index 'ani', - or None if the input lambda is None, or if 'ani' is negative or NaN. 
- """ - if lambda_val == None: - return None - contain_count = 0 - zero_count = 0 - for x in full_cov: - if x != 0: - contain_count += 1 - - else: - zero_count += 1 - - if not full_cov: - return None - - adj_index = contain_count / (1.0 - math.exp(-lambda_val)) / len(full_cov) - - # Calculating ani using math.pow for clarity (and to maintain type correctness) - ani = math.pow(adj_index, 1.0 / k_value) - - if ani < 0.0 or math.isnan(ani): - ret_ani = None - - else: - ret_ani = ani - - return ret_ani \ No newline at end of file From 3b5a37029988f6f8c1e81714c86aca72e2552220 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Fri, 5 Dec 2025 15:28:05 -0500 Subject: [PATCH 03/41] Removed remaining print statements throughout. --- src/yacht/cov_calc.py | 44 +++++----------------------- src/yacht/hypothesis_recovery_src.py | 16 ---------- src/yacht/utils.py | 10 +------ 3 files changed, 8 insertions(+), 62 deletions(-) diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py index 2d5ea5f5..64a508b6 100644 --- a/src/yacht/cov_calc.py +++ b/src/yacht/cov_calc.py @@ -54,9 +54,7 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma covs = [] contain_count = 0 for kmer in gn_hashes: - #print(kmer) if kmer in sample_hashes_keys: - #print(f"Overlap") if samp_dict[kmer] == 0: continue contain_count += 1 @@ -69,26 +67,18 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma 1/ksize) covs.sort() - #print(covs) if len(covs) == 0: - #print("Zero length") covs.append(0) - - #cov_set = set(covs) len_ind = len(covs)//2 - #print("len_ind") - #print(len_ind) median_cov = covs[len(covs)//2] - #print(median_cov) + pois_obj = poisson(median_cov) #creates a discrete frozen Poisson distribution object cov_max = float('inf') if median_cov < 30: #if median coverage of 30 is not fulfilled - #print(f"Below 30") for i in range(len_ind,len(covs), 1): - #print(i) cov = covs[i] if pois_obj.cdf(cov) < PVALUE_CUTOFF: cov_max = cov @@ 
-101,7 +91,6 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma if cov <= cov_max: full_covs.append(cov) var = variation(full_covs) - #print("Variation is:", var) if var is not None: logger.debug("VAR {} {}", var, genome_sig.name) @@ -109,7 +98,6 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma geq1_mean_cov = sum(full_covs)//len(covs) if median_cov > MEDIAN_ANI_THRESHOLD: return_lambda = ADJUST_STATUS_HIGH - #print(f"Above_threshold: {type(return_lambda).__name__}") #for testing else: if (myArgs.ratio == True): @@ -128,53 +116,38 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma else: return_lambda = AdjustStatusLambda(value=test_lambda) # Wrap the float in the dataclass - #print(f"Return lambda type: {type(return_lambda).__name__}") - match return_lambda: case AdjustStatusLambda(value=lam): - #print(f"Case1") # executes if it is the Lambda case final_est_cov = lam opt_lambda = final_est_cov - #print(f"Status is Lambda, coverage set to: {final_est_cov:.2f}") case AdjustStatusHigh(): # executes if it is high coverage case - #print(f"Case2") if median_cov < MAX_MEDIAN_FOR_MEAN_FINAL_EST: final_est_cov = geq1_mean_cov - #print(f"Status is High, using geq1_mean_cov logic") else: final_est_cov = median_cov - #print(f"Status is High, using median_cov logic") opt_lambda = final_est_cov case AdjustStatusLow(): - #print(f"Case3") # if it is the "low" case # final_est_cov logic is handled elsewhere, or use a default opt_lambda = None - #print("Status is Low, using naive_ani logic later") # Adding a "wild-card" case, just to be safe case _: - #print(f"Case Wildcard: Unexpected value or type {return_lambda}") opt_lambda = None - #print(f"Opt_lambda") - #print(opt_lambda) - - #print(f"opt_est_ani") opt_est_ani = ani_from_lambda(opt_lambda, mean_cov, 31, full_covs) - #print(opt_est_ani) if opt_lambda == None or opt_est_ani == None or no_adj == True: final_est_ani = naive_ani 
else: final_est_ani = opt_est_ani -#### This is the "winner_map" situation. I'm leaving it out of the codebase for now, but we can revisit this +# This is the "winner_map" section. I'm leaving it out of the codebase for now in case we would like to revisit this # Calculate min_ani using a conditional expression (Python's 'if/else if/else') #if args.minimum_ani is not None: @@ -213,16 +186,13 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma low_lambda = bootstrap[2] high_lambda = bootstrap[3] - #print(f"ci_values are as follows:") #for testing - #print(low_ani, high_ani, low_lambda, high_lambda) - - if sample_sig.name: + if sample_sig.name: seq_name = sample_sig.name else: seq_name = sample_sig.filename -#This is more code related to the winner_map situation - kmers_lost = kmers_lost_count if winner_map is not None else None +#This is code related to the winner_map situation + #kmers_lost = kmers_lost_count if winner_map is not None else None ani_result = AniResult( naive_ani=naive_ani, @@ -240,7 +210,7 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma genome_sketch=genome_sig, rel_abund=None, seq_abund=None, - kmers_lost=kmers_lost, + kmers_lost=None, ) results = [ @@ -249,7 +219,7 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma seq_name=seq_name, gn_name=genome_sig.filename, contig_name=genome_sig.name, mean_cov=geq1_mean_cov, median_cov=median_cov, containment_index=(contain_count, len(gn_hashes)), lambda_status=return_lambda, ani_ci=(low_ani, high_ani), lambda_ci=(low_lambda, high_lambda), - genome_sketch=genome_sig, rel_abund=None, seq_abund=None, kmers_lost=kmers_lost, + genome_sketch=genome_sig, rel_abund=None, seq_abund=None, kmers_lost=None, )] columns_ani = [ diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index 750bb403..b0d46c93 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ 
b/src/yacht/hypothesis_recovery_src.py @@ -182,8 +182,6 @@ def __find_exclusive_hashes( single_occurrence_hashes.add(hash) - #print(multiple_occurrence_hashes) - del multiple_occurrence_hashes # free up memory # Find hashes that are unique to each organism @@ -196,8 +194,6 @@ def __find_exclusive_hashes( ) ) - #print(f"Single occurrence hashes") - #print(single_occurrence_hashes) #adding this for testing del single_occurrence_hashes # free up memory # Get sample hashes @@ -205,7 +201,6 @@ def __find_exclusive_hashes( # Get sample hashes keys sample_hashes_keys = sample_sig.minhash.hashes.keys() - #print(sample_hashes_keys) samp_kmers_items = sample_sig.minhash.hashes.items() samp_dict = dict(samp_kmers_items) @@ -228,7 +223,6 @@ def __find_exclusive_hashes( for i, exclusive_hashes in enumerate( tqdm(exclusive_hashes_org, desc="Matching exclusive hashes with sample") ): - #print(f"exclusive_hash", exclusive_hashes) exclusive_hashes_info.append( (len(exclusive_hashes), len(exclusive_hashes.intersection(sample_hashes))) ) @@ -237,10 +231,6 @@ def __find_exclusive_hashes( #logger.info("Calculate lambda for each organism in the sample") #for i, lambda_stats in enumerate() - #print(type(exclusive_hashes_info)) - #print(exclusive_hashes_info) - #print(type(sub_manifest)) - columns_of_interest = [ 'naive_ani', 'final_est_ani', @@ -253,14 +243,9 @@ def __find_exclusive_hashes( # Select only those columns from the DataFrame selected_data = final_stats_df[columns_of_interest] - summary_stats = selected_data.describe() - #print(sub_manifest) - print(final_stats_df) print(summary_stats) - #print(final_stats_df['lambda_ci'].unique()) - #print(final_stats_df['ani_ci'].unique()) return exclusive_hashes_info, sub_manifest, final_stats_df @@ -471,7 +456,6 @@ def hypothesis_recovery( # Create a pandas dataframe to store the results results = pd.DataFrame(results, columns=given_columns) - #print(results) #for testing # combine the results with the manifest manifest["min_coverage"] = 
min_coverage diff --git a/src/yacht/utils.py b/src/yacht/utils.py index 38afcc54..90d61da8 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -579,8 +579,7 @@ def newton_raphson(ratio: float, mean: float): Shaw and Yu (2024)'s implmentation of Newton-Raphson use to assist in the calculation of lambda. """ curr = mean / (1 - ratio) - #print(1-mean) - #print(1-ratio) + for _ in range(1000): #iterates to converge on an approximation for the root t1 = (1 - ratio) * curr e_curr = math.exp(-curr) @@ -771,8 +770,6 @@ def bootstrap_interval(covs_full: list[int], k: float, args: _ContainArgs): if args.ci_int == False: return (None, None, None, None) - #logger.info("Bootstrap interval") - #print(f"Bootstrap interval") #for testing #TODO 12/3: look into whether/where this function is being activated num_samp = len(covs_full) iters = 100 res_ani = [] @@ -793,9 +790,6 @@ def bootstrap_interval(covs_full: list[int], k: float, args: _ContainArgs): else: lambda_val = ratio_lambda(rand_vec, args.min_count_correct) - #print(f"lambda_val is:") #for testing - #print(lambda_val) - ani_val = ani_from_lambda(lambda_val, np.mean(rand_vec), ksize, rand_vec) if ani_val is not None and lambda_val is not None: @@ -815,8 +809,6 @@ def bootstrap_interval(covs_full: list[int], k: float, args: _ContainArgs): low_lambda = res_lambda[suc * 5 // 100] high_lambda = res_lambda[suc * 95 // 100] - #print(f"Bootstrap interval") #for testing - #print(low_ani, high_ani, low_lambda, high_lambda) return (low_ani, high_ani, low_lambda, high_lambda) def ani_from_lambda(lambda_val, lam_mean, k_value, full_cov): From 6f0d8ff5a665880c6732ce0783b5be1600692205 Mon Sep 17 00:00:00 2001 From: "R. 
Taylor Raborn" Date: Fri, 5 Dec 2025 15:31:07 -0500 Subject: [PATCH 04/41] Moved test script [skip ci] --- .../yacht/internal_superyacht_test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename internal_superyacht_test.py => src/yacht/internal_superyacht_test.py (100%) diff --git a/internal_superyacht_test.py b/src/yacht/internal_superyacht_test.py similarity index 100% rename from internal_superyacht_test.py rename to src/yacht/internal_superyacht_test.py From e09ec7b82a89185b720147e65d952c7afc8ce7ea Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Fri, 5 Dec 2025 15:38:03 -0500 Subject: [PATCH 05/41] Corrected incorrect indent. [skip ci] --- src/yacht/cov_calc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py index 64a508b6..d6cdea90 100644 --- a/src/yacht/cov_calc.py +++ b/src/yacht/cov_calc.py @@ -186,7 +186,7 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma low_lambda = bootstrap[2] high_lambda = bootstrap[3] - if sample_sig.name: + if sample_sig.name: seq_name = sample_sig.name else: seq_name = sample_sig.filename From 50ff040c5e5206d8c950fd2b39c8772478a01936 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Mon, 29 Dec 2025 22:37:54 -0500 Subject: [PATCH 06/41] Various code improvements to improve legibility and remove stray bits that aren't necessary. 
--- src/yacht/cov_calc.py | 20 +++++++++---------- src/yacht/hypothesis_recovery_src.py | 4 +--- src/yacht/utils.py | 4 ++-- .../internal_superyacht_test.py | 0 4 files changed, 13 insertions(+), 15 deletions(-) rename {src/yacht => tests}/internal_superyacht_test.py (100%) diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py index d6cdea90..a2277d44 100644 --- a/src/yacht/cov_calc.py +++ b/src/yacht/cov_calc.py @@ -83,7 +83,14 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma if pois_obj.cdf(cov) < PVALUE_CUTOFF: cov_max = cov else: - break #consider adding RaiseError if (e.g.) cov_max=Inf + break + # Check if cov_max remains inf (i.e. no valid maximum found) + if cov_max == float('inf'): + logger.waning( + f"Could not determine valid coverage maximum for geneome {genome_sig.name}." + f"Median coverage: {median_cov}. Returning None." + ) + return None full_covs = [0] * (len(gn_hashes) - contain_count) @@ -213,14 +220,7 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma kmers_lost=None, ) - results = [ - AniResult( - naive_ani=naive_ani, final_est_ani=final_est_ani, final_est_cov=opt_lambda, - seq_name=seq_name, gn_name=genome_sig.filename, contig_name=genome_sig.name, - mean_cov=geq1_mean_cov, median_cov=median_cov, containment_index=(contain_count, len(gn_hashes)), - lambda_status=return_lambda, ani_ci=(low_ani, high_ani), lambda_ci=(low_lambda, high_lambda), - genome_sketch=genome_sig, rel_abund=None, seq_abund=None, kmers_lost=None, - )] + results = [ani_result] columns_ani = [ "naive_ani", # the ani according to naive calculations @@ -247,4 +247,4 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma - \ No newline at end of file + diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index b0d46c93..df418d66 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -14,12 +14,10 @@ from 
.utils import load_signature_with_ksize, decompress_all_sig_files # Configure Loguru logger from loguru import logger -from cov_calc import cov_calc +from .cov_calc import cov_calc warnings.filterwarnings("ignore") - -\ logger.remove() logger.add( sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}", level="INFO" diff --git a/src/yacht/utils.py b/src/yacht/utils.py index 90d61da8..48c0d193 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -618,7 +618,7 @@ def mle_zip(full_covs: list[int], _k: float): lambda_ret = lambda_out return lambda_ret -def variance(data: str(int)): +def variance(data: List[int]): """ An internal function that calculates the variance, which is the average of the squared differences of all values from the mean """ @@ -849,4 +849,4 @@ def ani_from_lambda(lambda_val, lam_mean, k_value, full_cov): else: ret_ani = ani - return ret_ani \ No newline at end of file + return ret_ani diff --git a/src/yacht/internal_superyacht_test.py b/tests/internal_superyacht_test.py similarity index 100% rename from src/yacht/internal_superyacht_test.py rename to tests/internal_superyacht_test.py From 4dd7972ae5a81be3fb192d629362bfc6088e520e Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Fri, 2 Jan 2026 16:54:59 -0500 Subject: [PATCH 07/41] Various bug fixes and improvements to code quality. Replaced relevant print statements with logger. Moved all constants to utils.py. 
--- src/yacht/cov_calc.py | 23 +++++++++-------------- src/yacht/hypothesis_recovery_src.py | 7 ++----- src/yacht/utils.py | 22 +++++++++++++--------- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py index a2277d44..80646e6c 100644 --- a/src/yacht/cov_calc.py +++ b/src/yacht/cov_calc.py @@ -13,16 +13,10 @@ from yacht.utils import _ContainArgs from yacht.utils import AniResult from yacht.utils import AdjustStatusLambda, AdjustStatusLow, AdjustStatusHigh, AdjustStatusNone +from yacht.utils import SAMPLE_SIZE_CUTOFF, PVALUE_CUTOFF, MEDIAN_ANI_THRESHOLD, MAX_MEDIAN_FOR_MEAN_FINAL_EST, MIN_COUNT_THRESH, ksize from scipy.stats import poisson, variation from typing import Optional, Tuple, Dict, Any -SAMPLE_SIZE_CUTOFF: int = 25 #using the sylph (Shaw and Yu, 2024) defaults here -PVALUE_CUTOFF: float = 0.9999999999 -MEDIAN_ANI_THRESHOLD: float = 2.00 -MAX_MEDIAN_FOR_MEAN_FINAL_EST: float = 15.0 -MIN_COUNT_THRESH=3 -ksize=31 - no_adj = False #consider updating this in future SUPERYACHT arguments winner_map = None #skipping this step in this version kmers_lost_count = None @@ -84,13 +78,14 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma cov_max = cov else: break - # Check if cov_max remains inf (i.e. no valid maximum found) - if cov_max == float('inf'): - logger.waning( - f"Could not determine valid coverage maximum for geneome {genome_sig.name}." - f"Median coverage: {median_cov}. Returning None." - ) - return None + + # Check if cov_max remains inf (i.e. no valid maximum found) + if cov_max == float('inf'): + logger.warning( + f"Could not determine valid coverage maximum for genome {genome_sig.name}. " + f"Median coverage: {median_cov}. Returning None." 
+ ) + return None full_covs = [0] * (len(gn_hashes) - contain_count) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index df418d66..a72665d1 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -142,9 +142,6 @@ def get_exclusive_hashes( 2. the number of unique hashes exclusive to the organism under consideration that are in the sample a new manifest dataframe that only contains the organisms that have non-zero overlap with the sample """ - pvalue_cutoff=0.9999999999 - min_count_thresh=3 #TODO: consider whether to change this value - def __find_exclusive_hashes( md5sum: str, path_to_temp_dir: str, @@ -243,7 +240,7 @@ def __find_exclusive_hashes( selected_data = final_stats_df[columns_of_interest] summary_stats = selected_data.describe() - print(summary_stats) + #print(summary_stats) return exclusive_hashes_info, sub_manifest, final_stats_df @@ -450,7 +447,7 @@ def hypothesis_recovery( for i in range(len(exclusive_hashes_info)) ) results = p.starmap(single_hyp_test, params) - print(f"Finished computing all results for min_coverage value: {min_coverage}") #for testing + logger.info(f"Finished computing all results for min_coverage value: {min_coverage}") # Create a pandas dataframe to store the results results = pd.DataFrame(results, columns=given_columns) diff --git a/src/yacht/utils.py b/src/yacht/utils.py index 48c0d193..6d62f07d 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -17,6 +17,7 @@ import sourmash from dataclasses import dataclass from glob import glob +from scipy.special import gamma # Configure Loguru logger logger.remove() @@ -27,10 +28,13 @@ # Set up constants COL_NOT_FOUND_ERROR = "Column not found: {}" FILE_LOCATION = os.path.dirname(os.path.realpath(__file__)) -# Adding two more contstants (RTR) -SAMPLE_SIZE_CUTOFF: int = 25 +# Sylph (Shaw and Yu, 2024) related constants +SAMPLE_SIZE_CUTOFF: int = 25 PVALUE_CUTOFF: float = 0.9999999999 -ksize = 31 
#Note: hard-coding this for now +MEDIAN_ANI_THRESHOLD: float = 2.00 +MAX_MEDIAN_FOR_MEAN_FINAL_EST: float = 15.0 +MIN_COUNT_THRESH: int = 3 +ksize: int = 31 # Note: hard-coding this for now # Set up global variables __version__ = "2.0.1" @@ -725,30 +729,30 @@ def mme_lambda(full_covs: list[int]) -> Optional[float]: def binary_search_lambda(full_covs: list[int]): if len(full_covs) == 0: return None - m = mean(full_covs) - v = var(full_covs) + m = np.mean(full_covs) + v = variance(full_covs) nonzero = 0 ones = 0 twos = 0 for x in full_covs: if x != 0: - _nonzero += 1 + nonzero += 1 if x == 1: ones += 1 elif x == 2: twos += 1 + if ones == 0: + return None ratio_est = float(twos) / float(ones) left = float(max(0.003, m - 2)) right = m + 5 - endpoints = ("start", "end") - left, right = endpoints best = None best_val = 10000 for i in range(10000): - test = (endpoints - endpoints)/10000 * float(i) + endpoints + test = (right - left)/10000 * float(i) + left proposed = ratio_from_moments_lambda(1, test, m, v); if proposed is not None: p = proposed - ratio_est From 92e25cc6a091ac01dcbfb4e1cfda7f1ed20b6d0d Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Tue, 6 Jan 2026 17:47:34 -0500 Subject: [PATCH 08/41] Changed to a more typically python enum for cov_calc. 
--- src/yacht/cov_calc.py | 30 +++++++++++----------------- src/yacht/utils.py | 46 ++++++++++++++++++++++++++++++++----------- 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py index 80646e6c..f11e05fb 100644 --- a/src/yacht/cov_calc.py +++ b/src/yacht/cov_calc.py @@ -12,7 +12,7 @@ from yacht.utils import ani_from_lambda from yacht.utils import _ContainArgs from yacht.utils import AniResult -from yacht.utils import AdjustStatusLambda, AdjustStatusLow, AdjustStatusHigh, AdjustStatusNone +from yacht.utils import AdjustStatus, AdjustStatusType from yacht.utils import SAMPLE_SIZE_CUTOFF, PVALUE_CUTOFF, MEDIAN_ANI_THRESHOLD, MAX_MEDIAN_FOR_MEAN_FINAL_EST, MIN_COUNT_THRESH, ksize from scipy.stats import poisson, variation from typing import Optional, Tuple, Dict, Any @@ -21,14 +21,6 @@ winner_map = None #skipping this step in this version kmers_lost_count = None -# Creates instances of the simple states -ADJUST_STATUS_NONE = AdjustStatusNone() -ADJUST_STATUS_HIGH = AdjustStatusHigh() -ADJUST_STATUS_LOW = AdjustStatusLow() - -# Define a Union type hint for clarity -AdjustStatus = AdjustStatusLambda | AdjustStatusHigh | AdjustStatusLow | AdjustStatusNone - def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.SourmashSignature): """ Function that calculates lambda according to Shaw and Yu (2024) from two sourmash.Minshash files (resresenting the sample and the genome sketches). 
@@ -97,14 +89,14 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma logger.debug("VAR {} {}", var, genome_sig.name) mean_cov = sum(full_covs)//len(full_covs) - geq1_mean_cov = sum(full_covs)//len(covs) + geq1_mean_cov = sum(full_covs)//len(covs) if median_cov > MEDIAN_ANI_THRESHOLD: - return_lambda = ADJUST_STATUS_HIGH + return_lambda = AdjustStatus.high() else: if (myArgs.ratio == True): test_lambda = ratio_lambda(full_covs, MIN_COUNT_THRESH) - elif (myArgs.mme == True): + elif (myArgs.mme == True): test_lambda = mme_lambda(full_covs) elif (myArgs.bin == True): test_lambda = binary_search_lambda(full_covs) @@ -114,18 +106,18 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma test_lambda = ratio_lambda(full_covs, MIN_COUNT_THRESH) if test_lambda is None: - return_lambda = ADJUST_STATUS_LOW #updated code + return_lambda = AdjustStatus.low() else: - return_lambda = AdjustStatusLambda(value=test_lambda) # Wrap the float in the dataclass + return_lambda = AdjustStatus.lambda_value(test_lambda) - match return_lambda: + match return_lambda.status: - case AdjustStatusLambda(value=lam): + case AdjustStatusType.LAMBDA: # executes if it is the Lambda case - final_est_cov = lam + final_est_cov = return_lambda.value opt_lambda = final_est_cov - case AdjustStatusHigh(): + case AdjustStatusType.HIGH: # executes if it is high coverage case if median_cov < MAX_MEDIAN_FOR_MEAN_FINAL_EST: final_est_cov = geq1_mean_cov @@ -133,7 +125,7 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma final_est_cov = median_cov opt_lambda = final_est_cov - case AdjustStatusLow(): + case AdjustStatusType.LOW: # if it is the "low" case # final_est_cov logic is handled elsewhere, or use a default opt_lambda = None diff --git a/src/yacht/utils.py b/src/yacht/utils.py index 6d62f07d..f3ccbd7e 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -43,19 +43,41 @@ BASE_URL = 
"https://farm.cse.ucdavis.edu/~ctbrown/sourmash-db/" ZENODO_COMMUNITY_URL = "https://zenodo.org/api/records/?communities=yacht&size=100" -# A dataclass to implement something equivalent to sylph's rust-based enum implementation (AdjustStatus) -@dataclass(frozen=True) -class AdjustStatusLambda: - value: float - -class AdjustStatusHigh: - pass +# Pythonic enum implementation for lambda adjustment status +from enum import Enum -class AdjustStatusLow: - pass +class AdjustStatusType(Enum): + """Status types for lambda adjustment.""" + LAMBDA = "lambda" + HIGH = "high" + LOW = "low" + NONE = "none" -class AdjustStatusNone: - pass +@dataclass(frozen=True) +class AdjustStatus: + """Lambda adjustment status with optional value.""" + status: AdjustStatusType + value: Optional[float] = None + + @classmethod + def lambda_value(cls, value: float): + """Create a Lambda status with a value.""" + return cls(AdjustStatusType.LAMBDA, value) + + @classmethod + def high(cls): + """Create a High status.""" + return cls(AdjustStatusType.HIGH) + + @classmethod + def low(cls): + """Create a Low status.""" + return cls(AdjustStatusType.LOW) + + @classmethod + def none(cls): + """Create a None status.""" + return cls(AdjustStatusType.NONE) # Class for cov_calc output @dataclass @@ -69,7 +91,7 @@ class AniResult: mean_cov: float median_cov: float containment_index: Tuple[int, int] - lambda_status: AdjustStatusLambda + lambda_status: AdjustStatus ani_ci: Tuple[Optional[float], Optional[float]] lambda_ci: Tuple[Optional[float], Optional[float]] genome_sketch: Any From 80bedab1faaebdc1dddfb8e346e6399777102aa6 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Sat, 10 Jan 2026 17:11:16 -0500 Subject: [PATCH 09/41] Minor update to README and .gitignore. 
--- .gitignore | 9 +++++++ README.md | 79 +++++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 82 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index c9474082..d1920cb0 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,12 @@ gtdb-rs214-reps.k31_0.9995_pretrained/ # added by mahmudhera src/cpp/main.o .gitignore + +# added by rtraborn +tests_sra_data/ +*.fastq +*.fastq.gz +*.sig.zip +*_temp/ +*_intermediate_files/ +*.sra \ No newline at end of file diff --git a/README.md b/README.md index a15e3d6a..8eddd8f3 100644 --- a/README.md +++ b/README.md @@ -341,22 +341,33 @@ The `--min_coverage_list` parameter dictates a list of `min_coverage` which indi The output file will be an EXCEL file; column descriptions can be found [here](docs/column_descriptions.csv). The most important are the following: +**Core Detection Columns:** * `organism_name`: The name of the organism -* `in_sample_est`: A boolean value either False or True: if False, there was not enough evidence to claim this organism is present in the sample. +* `in_sample_est`: A boolean value either False or True: if False, there was not enough evidence to claim this organism is present in the sample. * `p_vals`: Probability of observing this or more extreme result at the given ANI threshold, assuming the null hypothesis. - -Other interesting columns include: - * `num_exclusive_kmers_to_genome`: How many k-mers were found in this organism and no others * `num_matches`: How many k-mers were found in this organism and the sample * `acceptance_threshold_*`: How many k-mers must be found in this organism to be considered "present" at the given ANI threshold. Hence, `in_sample_est` is True if `num_matches` >= `acceptance_threshold_*` (adjusting by coverage if desired). * `alt_confidence_mut_rate_*`: What the mutation rate (1-ANI) would need to be to get your false positive to match the false negative rate of 1-`significance` (adjusting by coverage if desired). 
+**Coverage Statistics (described in sylph: Shaw & Yu, 2024 | New in superyacht):** +* `naive_ani`: Simple ANI estimate from k-mer containment (0-1, multiply by 100 for percentage) +* `final_est_ani`: Coverage-adjusted ANI estimate (more accurate than naive_ani) +* `final_est_cov`: Expected coverage (lambda parameter) - average sequencing depth for this organism +* `mean_cov` / `median_cov`: Coverage distribution statistics +* `lambda_status`: Coverage calculation method (LAMBDA/HIGH/LOW) + +**Relative Abundance (Winner Map):** +* `rel_abund`: Relative abundance estimate (0-1, normalized across all organisms in sample) +* `kmers_lost`: Number of k-mers reassigned to organisms with higher ANI + +**ANI Filtering:** Organisms with `final_est_ani < 0.90` (90% ANI) are automatically filtered from results to remove low-quality matches. +
### 4. Convert YACHT result to other popular output formats (yacht convert) -When we get the EXCEL result file from run_YACHT.py, you can run `yacht convert` to covert the YACHT result to other popular output formats (Currently, only `cami`, `biom`, `graphplan` are supported). +When you get the EXCEL result file from run_YACHT.py, you can run `yacht convert` to covert the YACHT result to other popular output formats (Currently, only `cami`, `biom`, `graphplan` are supported). __Note__: Before you run `yacht convert`, you need to prepare a TSV file `genome_to_taxid.tsv` containing two columns: genome ID (genome_id) and its corresponding taxid (taxid). An example can be found [here](demo/toy_genome_to_taxid.tsv). You need to prepare it according to the reference database genomes you used. @@ -374,8 +385,64 @@ yacht convert --yacht_output 'result.xlsx' --sheet_name 'min_coverage0.01' --gen | --genome_to_taxid | the path to the location of `genome_to_taxid.tsv` you prepared | | --mode | specify to which output format you want to convert (e.g., 'cami', 'biom', 'graphplan') | --sample_name | A random name you would like to show in header of the cami file. Default: Sample1.' | -| --outfile_prefix | the prefix of the output file. Default: result | +| --outfile_prefix | the prefix of the output file. Default: result | | --outdir | the path to output directory where the results will be genreated | +
+ +## Methods + +YACHT integrates three complementary approaches for organism detection and quantification: + +### 1. Hypothesis Testing (Presence/Absence) +- Uses **exclusive k-mers** (unique to each organism) for statistical testing +- Binomial test determines if enough exclusive k-mers observed to reject null hypothesis +- Accounts for mutation rate and minimum coverage thresholds +- Output: `in_sample_est` (True/False), `p_vals` + +### 2. Coverage Modeling (ANI & Abundance) +- Implements Shaw & Yu (2024) method from [sylph](https://github.com/bluenote-1577/sylph) +- Estimates effective ANI and coverage accounting for sequencing depth variation +- Calculates expected coverage (lambda) using Poisson distribution modeling +- More accurate than naive containment-based ANI +- Output: `final_est_ani`, `final_est_cov`, coverage statistics + +### 3. Winner Map K-mer Reassignment (Relative Abundance) +- Follows "winner takes all" strategy from sylph +- Shared k-mers between organisms assigned to organism with highest ANI +- Prevents double-counting in relative abundance calculations +- Single-pass design: uses coverage results without recalculation +- Output: `rel_abund` (normalized), `kmers_lost` + +### Workflow +``` +Sample Input + ↓ +Filter organisms (any k-mer overlap) + ↓ +Find exclusive k-mers (for hypothesis testing) + ↓ +Calculate coverage & ANI (Shaw & Yu method) + ↓ +Build winner_map (k-mer reassignment) + ↓ +Estimate relative abundance + ↓ +Run hypothesis tests + ↓ +Merge results + Filter by ANI threshold + ↓ +Excel Output +``` + +### Key Features +- **Exclusive k-mers**: Prevent false positives from shared sequences +- **Coverage adjustment**: More accurate ANI estimates +- **Single-pass performance**: Efficient computation +- **ANI filtering**: Removes low-quality matches (< 90% ANI) + +### References +1. Koslicki, D., White, S., Ma, C., & Novikov, A. (2024). YACHT: an ANI-based statistical test to detect microbial presence/absence in a metagenomic sample. 
*Bioinformatics*, 40(2), btae047. +2. Shaw, J., & Yu, Y. W. (2024). Rapid species-level metagenome profiling and containment estimation with sylph. *Nature Biotechnology*. https://doi.org/10.1038/s41587-024-02412-y From b1bd7264bfe71acd6223654f45f217cf9b038900 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Mon, 12 Jan 2026 14:37:56 -0500 Subject: [PATCH 10/41] Added winner map functionality to hypothesis_recovery_src.py. --- src/yacht/hypothesis_recovery_src.py | 220 +++++++++++++++++++++++---- 1 file changed, 189 insertions(+), 31 deletions(-) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index a72665d1..aed86679 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -10,12 +10,19 @@ from multiprocessing import Pool import sourmash import glob -from typing import List, Set, Tuple -from .utils import load_signature_with_ksize, decompress_all_sig_files +from typing import List, Set, Tuple, Dict +from .utils import load_signature_with_ksize, decompress_all_sig_files, MIN_ANI_THRESHOLD # Configure Loguru logger from loguru import logger from .cov_calc import cov_calc +""" +Hypothesis Recovery and Coverage Analysis Module + +This module implements YACHT's core statistical framework for organism detection +in metagenomic samples, with integrated coverage modeling and abundance estimation. 
+""" + warnings.filterwarnings("ignore") logger.remove() @@ -193,12 +200,8 @@ def __find_exclusive_hashes( # Get sample hashes sample_hashes = set(sample_sig.minhash.hashes) - - # Get sample hashes keys - sample_hashes_keys = sample_sig.minhash.hashes.keys() - samp_kmers_items = sample_sig.minhash.hashes.items() - samp_dict = dict(samp_kmers_items) - + + # Calculate coverage statistics for each organism stats_list = [] for md5sum in tqdm(organism_md5sum_list, desc="Processing coverage per organism"): sig = load_signature_with_ksize( @@ -207,9 +210,12 @@ def __find_exclusive_hashes( ) stats_out = cov_calc(sample_sig, sig) #location of cov_calc, which calculates effective coverage and other things according to Shaw and Yu (2024) stats_list.append(stats_out) - + final_stats_df = pd.concat(stats_list, ignore_index=True) - + + # Add organism_name to final_stats_df for merging (stats are in same order as sub_manifest) + final_stats_df['organism_name'] = sub_manifest['organism_name'].values + del stats_list # free up memory # Find hashes that are unique to each organism and in the sample @@ -222,27 +228,121 @@ def __find_exclusive_hashes( (len(exclusive_hashes), len(exclusive_hashes.intersection(sample_hashes))) ) - # Calculate lambda and other related coverage metrics for each organism in the sample - #logger.info("Calculate lambda for each organism in the sample") - #for i, lambda_stats in enumerate() - - columns_of_interest = [ - 'naive_ani', - 'final_est_ani', - 'final_est_cov', - 'mean_cov', - 'median_cov', - 'lambda_ci', - 'ani_ci' - ] + logger.info("Building winner map for k-mer reassignment and relative abundance estimation") + winner_map = build_winner_map(final_stats_df, path_to_genome_temp_dir, ksize) + final_stats_df = estimate_relative_abundance(final_stats_df, winner_map, sample_sig) - # Select only those columns from the DataFrame - selected_data = final_stats_df[columns_of_interest] - summary_stats = selected_data.describe() + return exclusive_hashes_info, 
sub_manifest, final_stats_df - #print(summary_stats) - return exclusive_hashes_info, sub_manifest, final_stats_df +def build_winner_map( + final_stats_df: pd.DataFrame, + path_to_genome_temp_dir: str, + ksize: int +) -> Dict[int, Tuple[float, str]]: + """ + Creates a "winner map" procedure that assigns k-mers to the organism with the highest ANI. + + This implements the "winner takes all" strategy from sylph (Shaw and Yu, 2024) where + shared k-mers are assigned to the organism with the best ANI match, preventing double-counting. + Please note that this differs from the approach in sylph in that the procedure is run once, rather than + twice. + + :param final_stats_df: DataFrame with coverage statistics including organism_name, + final_est_ani, and genome_sketch columns + :param path_to_genome_temp_dir: Path to the directory containing genome signature files + :param ksize: k-mer size + :return: Dictionary mapping k-mer hash -> (ani, organism_name) + Only the organism with highest ANI "wins" each k-mer + """ + from .utils import load_signature_with_ksize + + winner_map = {} + + logger.info("Building winner map for k-mer reassignment") + + for idx, row in tqdm(final_stats_df.iterrows(), total=len(final_stats_df), desc="Building winner map"): + organism_name = row['organism_name'] + ani = row['final_est_ani'] + + # Skip organisms with no ANI estimate + if pd.isna(ani): + continue + + # Load genome signature to get k-mers + genome_sig = row['genome_sketch'] + + # For each k-mer in this genome, check if it should be reassigned + for kmer in genome_sig.minhash.hashes.keys(): + if kmer not in winner_map or ani > winner_map[kmer][0]: + # This organism has higher ANI, so it "wins" this k-mer + winner_map[kmer] = (ani, organism_name) + + logger.info(f"Winner map built with {len(winner_map)} k-mers assigned to {len(final_stats_df)} organisms") + + return winner_map + + +def estimate_relative_abundance( + final_stats_df: pd.DataFrame, + winner_map: Dict[int, Tuple[float, 
str]], + sample_sig: sourmash.SourmashSignature +) -> pd.DataFrame: + """ + Estimates the relative abundance of each organism based on winner_map k-mer assignments. + + After winner_map assigns shared k-mers to organisms with highest ANI, this calculates: + 1. How many k-mers each organism "lost" to others (kmers_lost) + 2. Total coverage from k-mers "won" by each organism (used for relative abundance) + 3. Relative abundance normalized across all organisms + + :param final_stats_df: DataFrame with coverage statistics + :param winner_map: k-mer to (ANI, organism_name) mapping from build_winner_map() + :param sample_sig: Sample signature with k-mer abundances + :return: Updated DataFrame with rel_abund and kmers_lost columns populated + """ + logger.info("Estimating relative abundance using winner map") + + # Initialize columns + final_stats_df['kmers_lost'] = 0 + final_stats_df['rel_abund'] = 0.0 + + sample_hashes = sample_sig.minhash.hashes + + for idx, row in tqdm(final_stats_df.iterrows(), total=len(final_stats_df), desc="Calculating relative abundance"): + organism_name = row['organism_name'] + genome_sig = row['genome_sketch'] + + kmers_lost_count = 0 + total_coverage = 0.0 + + # Check each k-mer in this genome + for kmer in genome_sig.minhash.hashes.keys(): + # Check if this organism "won" this k-mer + if kmer in winner_map: + winner_organism = winner_map[kmer][1] + + if winner_organism != organism_name: + # This k-mer was reassigned to another organism + kmers_lost_count += 1 + else: + # This organism won this k-mer - count its coverage + if kmer in sample_hashes: + total_coverage += sample_hashes[kmer] + + final_stats_df.at[idx, 'kmers_lost'] = kmers_lost_count + final_stats_df.at[idx, 'rel_abund'] = total_coverage + + # Normalize relative abundance to sum to 1.0 across all organisms + total_abundance = final_stats_df['rel_abund'].sum() + if total_abundance > 0: + final_stats_df['rel_abund'] = final_stats_df['rel_abund'] / total_abundance + 
logger.info(f"Relative abundance normalized (total coverage: {total_abundance:.2f})") + else: + logger.warning("No coverage found for relative abundance calculation") + + return final_stats_df + def get_alt_mut_rate( nu: int, thresh: int, ksize: int, significance: float = 0.99 @@ -288,7 +388,6 @@ def single_hyp_test( """ # get the number of unique k-mers num_exclusive_kmers = exclusive_hashes_info_org[0] - #print(exclusive_hashes_info_org) ##printing the output of this to determine what the data structure looks like # mutation rate non_mut_p = (ani_thresh) ** ksize # # assuming coverage of 1, how many unique k-mers would I need to observe in order to reject the null hypothesis? @@ -321,7 +420,6 @@ def single_hyp_test( # How many unique k-mers do I actually see? num_matches = exclusive_hashes_info_org[1] - #print(num_matches) #printing this for testing # calculate the p-value considering the coverage if num_matches <= num_exclusive_kmers_coverage: p_val = binom.cdf(num_matches, num_exclusive_kmers_coverage, non_mut_p) @@ -456,4 +554,64 @@ def hypothesis_recovery( manifest["min_coverage"] = min_coverage manifest_list.append(pd.concat([manifest, results], axis=1)) - return manifest_list, final_stats_df + # Merge coverage statistics into each manifest DataFrame + # Select key coverage columns to include in output (including winner_map results) + coverage_cols = [ + 'organism_name', + 'naive_ani', + 'final_est_ani', + 'final_est_cov', + 'mean_cov', + 'median_cov', + 'lambda_status', + 'ani_ci', + 'lambda_ci', + 'rel_abund', # Relative abundance from winner_map + 'kmers_lost' # K-mers reassigned to other organisms + ] + coverage_stats = final_stats_df[coverage_cols].copy() + + # Merge coverage stats into each manifest in the list + for i in range(len(manifest_list)): + manifest_list[i] = manifest_list[i].merge( + coverage_stats, + on='organism_name', + how='left' # Keep all organisms, even those without coverage stats + ) + + # 
============================================================================ + # ANI Threshold Filtering + # ============================================================================ + # + # After winner_map k-mer reassignment, filter organisms with low ANI. + # + # Why filter? + # - Removes poor matches (distant relatives, contamination, low complexity) + # - Improves result quality by eliminating noise + # - Standard practice in metagenomic profiling + # + # Threshold: MIN_ANI_THRESHOLD = 0.90 (90% ANI) + # - Matches sylph's MIN_ANI_DEF default + # - 90% ANI commonly used for genus-level distinction + # - Well-supported by microbial genomics literature + # + # Implementation: + # Keep organisms with final_est_ani >= 0.90 OR NaN ANI + # (NaN kept because hypothesis test may still be valid via exclusive k-mers) + # + # Customization: + # To change threshold, modify MIN_ANI_THRESHOLD in utils.py + # + logger.info(f"Filtering organisms with final_est_ani < {MIN_ANI_THRESHOLD} ({MIN_ANI_THRESHOLD*100:.0f}% ANI)") + for i in range(len(manifest_list)): + initial_count = len(manifest_list[i]) + # Keep organisms with ANI >= threshold OR organisms with no ANI estimate (NaN) + manifest_list[i] = manifest_list[i][ + (manifest_list[i]['final_est_ani'] >= MIN_ANI_THRESHOLD) | + (manifest_list[i]['final_est_ani'].isna()) + ].reset_index(drop=True) + filtered_count = initial_count - len(manifest_list[i]) + if filtered_count > 0: + logger.info(f" Filtered {filtered_count} organisms below ANI threshold from min_coverage={manifest_list[i]['min_coverage'].iloc[0] if len(manifest_list[i]) > 0 else 'N/A'} results") + + return manifest_list From bab404fde315d270abf96741f4354057dfa2251a Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Mon, 12 Jan 2026 15:29:15 -0500 Subject: [PATCH 11/41] Minor update to .gitignore. 
--- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index d1920cb0..ce9ee735 100644 --- a/.gitignore +++ b/.gitignore @@ -180,4 +180,5 @@ tests_sra_data/ *.sig.zip *_temp/ *_intermediate_files/ -*.sra \ No newline at end of file +*.srademo/query_data/*.zip +*.sra From b8b476a1fe0b856b9b9ca797ba86eedeb86405e8 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Wed, 14 Jan 2026 12:03:45 -0500 Subject: [PATCH 12/41] Updated utils.py to define MIN_ANI_THRESHOLD. --- src/yacht/utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/yacht/utils.py b/src/yacht/utils.py index f3ccbd7e..bef73c3f 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -31,9 +31,11 @@ # Sylph (Shaw and Yu, 2024) related constants SAMPLE_SIZE_CUTOFF: int = 25 PVALUE_CUTOFF: float = 0.9999999999 +MIN_ANI_THRESHOLD: float = 0.90 # Minimum ANI threshold for filtering organisms MEDIAN_ANI_THRESHOLD: float = 2.00 MAX_MEDIAN_FOR_MEAN_FINAL_EST: float = 15.0 MIN_COUNT_THRESH: int = 3 +LAMBDA_EPSILON: float = 1e-10 # Minimum lambda value to avoid dividing by zero ksize: int = 31 # Note: hard-coding this for now # Set up global variables @@ -847,22 +849,27 @@ def ani_from_lambda(lambda_val, lam_mean, k_value, full_cov): full_cov: A list of integers to analyze for non-zero counts. Returns: - An optional float representing the calculated adjusted index 'ani', + An optional float representing the calculated adjusted index 'ani', or None if the input lambda is None, or if 'ani' is negative or NaN. 
""" if lambda_val == None: return None + + # Check if lambda is too close to zero so that we avoid dividing by zero + if abs(lambda_val) < LAMBDA_EPSILON: + return None + contain_count = 0 zero_count = 0 for x in full_cov: if x != 0: contain_count += 1 - + else: zero_count += 1 if not full_cov: - return None + return None adj_index = contain_count / (1.0 - math.exp(-lambda_val)) / len(full_cov) From 1f81fccbd93d1c3aac62f6cd21f74fb557e3e39d Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Wed, 14 Jan 2026 13:52:06 -0500 Subject: [PATCH 13/41] Added parallelization to the coverage calculation; should increase speed. --- src/yacht/hypothesis_recovery_src.py | 55 +++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index aed86679..466f6e39 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -119,13 +119,33 @@ def get_organisms_with_nonzero_overlap( return multisearch_result["match_name"].to_list() +def _calculate_coverage_worker(args): + """ + Worker function for parallel coverage calculation. 
+ + :param args: Tuple of (md5sum, path_to_genome_temp_dir, ksize, sample_sig) + :return: Result from cov_calc or None if error occurs + """ + md5sum, path_to_genome_temp_dir, ksize, sample_sig = args + try: + sig = load_signature_with_ksize( + os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX), + ksize, + ) + return cov_calc(sample_sig, sig) + except Exception as e: + logger.warning(f"Error calculating coverage for {md5sum}: {e}") + return None + + def get_exclusive_hashes( manifest: pd.DataFrame, nontrivial_organism_names: List[str], sample_sig: sourmash.SourmashSignature, ksize: int, path_to_genome_temp_dir: str, -) -> Tuple[List[Tuple[int, int]], pd.DataFrame]: + num_threads: int = 16, +) -> Tuple[List[Tuple[int, int]], pd.DataFrame, pd.DataFrame]: """ This function gets the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample, and then find how many are in the sample. @@ -143,11 +163,13 @@ def get_exclusive_hashes( :param sample_sig: the sample signature :param ksize: int (size of kmer) :param path_to_genome_temp_dir: string (path to the genome temporary directory generated by the training step) + :param num_threads: int (number of threads to use for parallel coverage calculation, default: 16) :return: a list of tuples, each tuple contains the following information: 1. the number of unique hashes exclusive to the organism under consideration 2. 
the number of unique hashes exclusive to the organism under consideration that are in the sample a new manifest dataframe that only contains the organisms that have non-zero overlap with the sample + a dataframe with coverage statistics for each organism """ def __find_exclusive_hashes( md5sum: str, @@ -201,15 +223,28 @@ def __find_exclusive_hashes( # Get sample hashes sample_hashes = set(sample_sig.minhash.hashes) - # Calculate coverage statistics for each organism - stats_list = [] - for md5sum in tqdm(organism_md5sum_list, desc="Processing coverage per organism"): - sig = load_signature_with_ksize( - os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX), - ksize, + # Calculate coverage statistics for each organism (parallelized) + logger.info(f"Calculating coverage statistics using {num_threads} threads") + with Pool(processes=num_threads) as pool: + # Prepare arguments for parallel processing + args_list = [ + (md5sum, path_to_genome_temp_dir, ksize, sample_sig) + for md5sum in organism_md5sum_list + ] + # Use imap for progress tracking with tqdm + stats_list = list( + tqdm( + pool.imap(_calculate_coverage_worker, args_list), + total=len(organism_md5sum_list), + desc="Processing coverage per organism" ) - stats_out = cov_calc(sample_sig, sig) #location of cov_calc, which calculates effective coverage and other things according to Shaw and Yu (2024) - stats_list.append(stats_out) + ) + + # Filter out None results (from errors) + stats_list = [stats for stats in stats_list if stats is not None] + + if not stats_list: + raise ValueError("No coverage statistics were successfully calculated") final_stats_df = pd.concat(stats_list, ignore_index=True) @@ -508,7 +543,7 @@ def hypothesis_recovery( # Get the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample exclusive_hashes_info, manifest, final_stats_df = get_exclusive_hashes( - manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir + 
manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir, num_threads ) # Set up the results dataframe columns From 12fcde9d6019614d332c1bf528a03023167b2a5e Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Wed, 14 Jan 2026 14:31:46 -0500 Subject: [PATCH 14/41] Cleaning up documentation. --- README.md | 55 ------------------------------------------------------- 1 file changed, 55 deletions(-) diff --git a/README.md b/README.md index 8eddd8f3..de4bcf3a 100644 --- a/README.md +++ b/README.md @@ -390,59 +390,4 @@ yacht convert --yacht_output 'result.xlsx' --sheet_name 'min_coverage0.01' --gen
-## Methods - -YACHT integrates three complementary approaches for organism detection and quantification: - -### 1. Hypothesis Testing (Presence/Absence) -- Uses **exclusive k-mers** (unique to each organism) for statistical testing -- Binomial test determines if enough exclusive k-mers observed to reject null hypothesis -- Accounts for mutation rate and minimum coverage thresholds -- Output: `in_sample_est` (True/False), `p_vals` - -### 2. Coverage Modeling (ANI & Abundance) -- Implements Shaw & Yu (2024) method from [sylph](https://github.com/bluenote-1577/sylph) -- Estimates effective ANI and coverage accounting for sequencing depth variation -- Calculates expected coverage (lambda) using Poisson distribution modeling -- More accurate than naive containment-based ANI -- Output: `final_est_ani`, `final_est_cov`, coverage statistics - -### 3. Winner Map K-mer Reassignment (Relative Abundance) -- Follows "winner takes all" strategy from sylph -- Shared k-mers between organisms assigned to organism with highest ANI -- Prevents double-counting in relative abundance calculations -- Single-pass design: uses coverage results without recalculation -- Output: `rel_abund` (normalized), `kmers_lost` - -### Workflow -``` -Sample Input - ↓ -Filter organisms (any k-mer overlap) - ↓ -Find exclusive k-mers (for hypothesis testing) - ↓ -Calculate coverage & ANI (Shaw & Yu method) - ↓ -Build winner_map (k-mer reassignment) - ↓ -Estimate relative abundance - ↓ -Run hypothesis tests - ↓ -Merge results + Filter by ANI threshold - ↓ -Excel Output -``` - -### Key Features -- **Exclusive k-mers**: Prevent false positives from shared sequences -- **Coverage adjustment**: More accurate ANI estimates -- **Single-pass performance**: Efficient computation -- **ANI filtering**: Removes low-quality matches (< 90% ANI) - -### References -1. Koslicki, D., White, S., Ma, C., & Novikov, A. (2024). YACHT: an ANI-based statistical test to detect microbial presence/absence in a metagenomic sample. 
*Bioinformatics*, 40(2), btae047. -2. Shaw, J., & Yu, Y. W. (2024). Rapid species-level metagenome profiling and containment estimation with sylph. *Nature Biotechnology*. https://doi.org/10.1038/s41587-024-02412-y - From 94cb13adf5bdf7b6db3214ff02477ed60238ded6 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Wed, 14 Jan 2026 17:22:58 -0500 Subject: [PATCH 15/41] Added sample_sig as a pool variable to remove big performance overheads. --- src/yacht/hypothesis_recovery_src.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index 466f6e39..adbbb67b 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -119,20 +119,33 @@ def get_organisms_with_nonzero_overlap( return multisearch_result["match_name"].to_list() +# Adding a global variable for sharing sample_sig across worker processes (reduces overhead) +_worker_sample_sig = None + +def _init_coverage_worker(sample_sig): + """ + Initializer for worker processes to set up shared sample signature. + + :param sample_sig: Sample signature to be shared across all workers + """ + global _worker_sample_sig + _worker_sample_sig = sample_sig + def _calculate_coverage_worker(args): """ Worker function for parallel coverage calculation. + Uses global _worker_sample_sig instead of passing it as argument to avoid overhead. 
- :param args: Tuple of (md5sum, path_to_genome_temp_dir, ksize, sample_sig) + :param args: Tuple of (md5sum, path_to_genome_temp_dir, ksize) :return: Result from cov_calc or None if error occurs """ - md5sum, path_to_genome_temp_dir, ksize, sample_sig = args + md5sum, path_to_genome_temp_dir, ksize = args try: sig = load_signature_with_ksize( os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX), ksize, ) - return cov_calc(sample_sig, sig) + return cov_calc(_worker_sample_sig, sig) except Exception as e: logger.warning(f"Error calculating coverage for {md5sum}: {e}") return None @@ -225,10 +238,10 @@ def __find_exclusive_hashes( # Calculate coverage statistics for each organism (parallelized) logger.info(f"Calculating coverage statistics using {num_threads} threads") - with Pool(processes=num_threads) as pool: - # Prepare arguments for parallel processing + with Pool(processes=num_threads, initializer=_init_coverage_worker, initargs=(sample_sig,)) as pool: + # Prepare arguments for parallel processing (sample_sig shared to avoid overhead) args_list = [ - (md5sum, path_to_genome_temp_dir, ksize, sample_sig) + (md5sum, path_to_genome_temp_dir, ksize) for md5sum in organism_md5sum_list ] # Use imap for progress tracking with tqdm From 1f643c1994c917bde66f6771700e594b7b326c3f Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Thu, 15 Jan 2026 15:24:21 -0500 Subject: [PATCH 16/41] Improvements to parallel processing and new arguments for the winner-takes-all k-mer reassignment. 
--- src/yacht/hypothesis_recovery_src.py | 166 ++++++++++++++++++--------- src/yacht/run_YACHT.py | 26 +++++ 2 files changed, 136 insertions(+), 56 deletions(-) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index adbbb67b..80446a66 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -119,7 +119,7 @@ def get_organisms_with_nonzero_overlap( return multisearch_result["match_name"].to_list() -# Adding a global variable for sharing sample_sig across worker processes (reduces overhead) +# Global variable for sharing sample signature across worker processes _worker_sample_sig = None def _init_coverage_worker(sample_sig): @@ -134,7 +134,7 @@ def _init_coverage_worker(sample_sig): def _calculate_coverage_worker(args): """ Worker function for parallel coverage calculation. - Uses global _worker_sample_sig instead of passing it as argument to avoid overhead. + Uses global _worker_sample_sig instead of passing it as argument to avoid pickling overhead. 
:param args: Tuple of (md5sum, path_to_genome_temp_dir, ksize) :return: Result from cov_calc or None if error occurs @@ -158,6 +158,8 @@ def get_exclusive_hashes( ksize: int, path_to_genome_temp_dir: str, num_threads: int = 16, + winner_takes_all: bool = False, + batch_size: int = 1000, ) -> Tuple[List[Tuple[int, int]], pd.DataFrame, pd.DataFrame]: """ This function gets the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample, and @@ -177,6 +179,8 @@ def get_exclusive_hashes( :param ksize: int (size of kmer) :param path_to_genome_temp_dir: string (path to the genome temporary directory generated by the training step) :param num_threads: int (number of threads to use for parallel coverage calculation, default: 16) + :param winner_takes_all: bool (enable winner-takes-all k-mer reassignment and relative abundance, default: False) + :param batch_size: int (batch size for winner-takes-all processing, default: 1000) :return: a list of tuples, each tuple contains the following information: 1. 
the number of unique hashes exclusive to the organism under consideration @@ -238,16 +242,21 @@ def __find_exclusive_hashes( # Calculate coverage statistics for each organism (parallelized) logger.info(f"Calculating coverage statistics using {num_threads} threads") + + # Calculate (organism) chunk size for progress visibility + chunk_size = max(1, len(organism_md5sum_list) // (num_threads * 50)) + logger.info(f"Using chunk size of {chunk_size} for parallel processing") + with Pool(processes=num_threads, initializer=_init_coverage_worker, initargs=(sample_sig,)) as pool: - # Prepare arguments for parallel processing (sample_sig shared to avoid overhead) + # Prepare arguments for parallel processing (sample_sig shared via initializer to avoid pickling overhead) args_list = [ (md5sum, path_to_genome_temp_dir, ksize) for md5sum in organism_md5sum_list ] - # Use imap for progress tracking with tqdm + # Use imap_unordered for better performance (order doesn't matter for coverage stats) stats_list = list( tqdm( - pool.imap(_calculate_coverage_worker, args_list), + pool.imap_unordered(_calculate_coverage_worker, args_list, chunksize=chunk_size), total=len(organism_md5sum_list), desc="Processing coverage per organism" ) @@ -276,9 +285,25 @@ def __find_exclusive_hashes( (len(exclusive_hashes), len(exclusive_hashes.intersection(sample_hashes))) ) - logger.info("Building winner map for k-mer reassignment and relative abundance estimation") - winner_map = build_winner_map(final_stats_df, path_to_genome_temp_dir, ksize) - final_stats_df = estimate_relative_abundance(final_stats_df, winner_map, sample_sig) + # Conditionally run winner-takes-all (memory-intensive but provides relative abundance) + if winner_takes_all: + logger.info("Building winner map for k-mer reassignment (winner-takes-all strategy)") + winner_map = build_winner_map(final_stats_df, path_to_genome_temp_dir, ksize, batch_size) + final_stats_df = estimate_relative_abundance(final_stats_df, winner_map, sample_sig, 
batch_size) + + # Free up memory by dropping genome_sketch column (no longer needed) + if 'genome_sketch' in final_stats_df.columns: + logger.info("Releasing genome signature objects to free memory") + final_stats_df.drop(columns=['genome_sketch'], inplace=True) + else: + logger.info("Skipping winner-takes-all (not enabled). Use --winner_takes_all to enable relative abundance estimation.") + # Add placeholder columns for consistency + final_stats_df['rel_abund'] = float('nan') + final_stats_df['kmers_lost'] = 0 + + # Also drop genome_sketch to save memory + if 'genome_sketch' in final_stats_df.columns: + final_stats_df.drop(columns=['genome_sketch'], inplace=True) return exclusive_hashes_info, sub_manifest, final_stats_df @@ -286,10 +311,12 @@ def __find_exclusive_hashes( def build_winner_map( final_stats_df: pd.DataFrame, path_to_genome_temp_dir: str, - ksize: int + ksize: int, + batch_size: int = 1000 ) -> Dict[int, Tuple[float, str]]: """ Creates a "winner map" procedure that assigns k-mers to the organism with the highest ANI. + Uses memory-efficient batch processing. This implements the "winner takes all" strategy from sylph (Shaw and Yu, 2024) where shared k-mers are assigned to the organism with the best ANI match, preventing double-counting. 
@@ -300,33 +327,42 @@ def build_winner_map( final_est_ani, and genome_sketch columns :param path_to_genome_temp_dir: Path to the directory containing genome signature files :param ksize: k-mer size + :param batch_size: Number of organisms to process per batch (default: 1000) :return: Dictionary mapping k-mer hash -> (ani, organism_name) Only the organism with highest ANI "wins" each k-mer """ - from .utils import load_signature_with_ksize - winner_map = {} + total_organisms = len(final_stats_df) + + logger.info(f"Building winner map for {total_organisms} organisms (batch size: {batch_size})") + + # Process in batches to control memory usage + for batch_start in range(0, total_organisms, batch_size): + batch_end = min(batch_start + batch_size, total_organisms) + batch_num = batch_start // batch_size + 1 + total_batches = (total_organisms + batch_size - 1) // batch_size + + for idx in tqdm( + range(batch_start, batch_end), + desc=f"Building winner map (batch {batch_num}/{total_batches})", + total=batch_end - batch_start + ): + row = final_stats_df.iloc[idx] + organism_name = row['organism_name'] + ani = row['final_est_ani'] + + # Skip organisms with no ANI estimate + if pd.isna(ani): + continue - logger.info("Building winner map for k-mer reassignment") - - for idx, row in tqdm(final_stats_df.iterrows(), total=len(final_stats_df), desc="Building winner map"): - organism_name = row['organism_name'] - ani = row['final_est_ani'] - - # Skip organisms with no ANI estimate - if pd.isna(ani): - continue - - # Load genome signature to get k-mers - genome_sig = row['genome_sketch'] + genome_sig = row['genome_sketch'] - # For each k-mer in this genome, check if it should be reassigned - for kmer in genome_sig.minhash.hashes.keys(): - if kmer not in winner_map or ani > winner_map[kmer][0]: - # This organism has higher ANI, so it "wins" this k-mer - winner_map[kmer] = (ani, organism_name) + # For each k-mer, check if it should be reassigned + for kmer in 
genome_sig.minhash.hashes.keys(): + if kmer not in winner_map or ani > winner_map[kmer][0]: + winner_map[kmer] = (ani, organism_name) - logger.info(f"Winner map built with {len(winner_map)} k-mers assigned to {len(final_stats_df)} organisms") + logger.info(f"Winner map built with {len(winner_map)} k-mers assigned") return winner_map @@ -334,10 +370,12 @@ def build_winner_map( def estimate_relative_abundance( final_stats_df: pd.DataFrame, winner_map: Dict[int, Tuple[float, str]], - sample_sig: sourmash.SourmashSignature + sample_sig: sourmash.SourmashSignature, + batch_size: int = 1000 ) -> pd.DataFrame: """ Estimates the relative abundance of each organism based on winner_map k-mer assignments. + Uses memory-efficient batch processing. After winner_map assigns shared k-mers to organisms with highest ANI, this calculates: 1. How many k-mers each organism "lost" to others (kmers_lost) @@ -347,39 +385,52 @@ def estimate_relative_abundance( :param final_stats_df: DataFrame with coverage statistics :param winner_map: k-mer to (ANI, organism_name) mapping from build_winner_map() :param sample_sig: Sample signature with k-mer abundances + :param batch_size: Number of organisms to process per batch (default: 1000) :return: Updated DataFrame with rel_abund and kmers_lost columns populated """ - logger.info("Estimating relative abundance using winner map") + logger.info("Estimating relative abundance using winner map (memory-efficient mode)") # Initialize columns final_stats_df['kmers_lost'] = 0 final_stats_df['rel_abund'] = 0.0 sample_hashes = sample_sig.minhash.hashes - - for idx, row in tqdm(final_stats_df.iterrows(), total=len(final_stats_df), desc="Calculating relative abundance"): - organism_name = row['organism_name'] - genome_sig = row['genome_sketch'] - - kmers_lost_count = 0 - total_coverage = 0.0 - - # Check each k-mer in this genome - for kmer in genome_sig.minhash.hashes.keys(): - # Check if this organism "won" this k-mer - if kmer in winner_map: - 
winner_organism = winner_map[kmer][1] - - if winner_organism != organism_name: - # This k-mer was reassigned to another organism - kmers_lost_count += 1 - else: - # This organism won this k-mer - count its coverage - if kmer in sample_hashes: - total_coverage += sample_hashes[kmer] - - final_stats_df.at[idx, 'kmers_lost'] = kmers_lost_count - final_stats_df.at[idx, 'rel_abund'] = total_coverage + total_organisms = len(final_stats_df) + + # Process in batches to reduce memory usage + for batch_start in range(0, total_organisms, batch_size): + batch_end = min(batch_start + batch_size, total_organisms) + batch_num = batch_start // batch_size + 1 + total_batches = (total_organisms + batch_size - 1) // batch_size + + for idx in tqdm( + range(batch_start, batch_end), + desc=f"Calculating relative abundance (batch {batch_num}/{total_batches})", + total=batch_end - batch_start + ): + row = final_stats_df.iloc[idx] + organism_name = row['organism_name'] + genome_sig = row['genome_sketch'] + + kmers_lost_count = 0 + total_coverage = 0.0 + + # Check each k-mer in this genome + for kmer in genome_sig.minhash.hashes.keys(): + # Check if this organism "won" this k-mer + if kmer in winner_map: + winner_organism = winner_map[kmer][1] + + if winner_organism != organism_name: + # This k-mer was reassigned to another organism + kmers_lost_count += 1 + else: + # This organism won this k-mer - count its coverage + if kmer in sample_hashes: + total_coverage += sample_hashes[kmer] + + final_stats_df.at[idx, 'kmers_lost'] = kmers_lost_count + final_stats_df.at[idx, 'rel_abund'] = total_coverage # Normalize relative abundance to sum to 1.0 across all organisms total_abundance = final_stats_df['rel_abund'].sum() @@ -502,6 +553,8 @@ def hypothesis_recovery( significance: float = 0.99, ani_thresh: float = 0.95, num_threads: int = 16, + winner_takes_all: bool = False, + batch_size: int = 1000, ): """ Go through each of the organisms that have non-zero overlap with the sample and perform a 
hypothesis test for the @@ -556,7 +609,8 @@ def hypothesis_recovery( # Get the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample exclusive_hashes_info, manifest, final_stats_df = get_exclusive_hashes( - manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir, num_threads + manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir, + num_threads, winner_takes_all, batch_size ) # Set up the results dataframe columns diff --git a/src/yacht/run_YACHT.py b/src/yacht/run_YACHT.py index dc5d5b02..ea768295 100644 --- a/src/yacht/run_YACHT.py +++ b/src/yacht/run_YACHT.py @@ -45,6 +45,21 @@ def add_arguments(parser): required=False, default=16, ) + parser.add_argument( + "--winner_takes_all", + action="store_true", + help="Enable winner-takes-all k-mer reassignment and relative abundance estimation. " + "Shared k-mers are assigned to the organism with highest ANI. " + "Uses memory-efficient batch processing. More accurate but slower.", + default=False, + ) + parser.add_argument( + "--batch_size", + type=int, + help="Batch size for winner-takes-all processing (lower = less memory, slower). " + "Only used with --winner_takes_all. Default: 1000", + default=1000, + ) parser.add_argument( "--keep_raw", action="store_true", help="Keep raw results in output file." ) @@ -77,6 +92,8 @@ def main(args): sample_file = str(Path(args.sample_file).absolute()) # location of sample.sig file significance = args.significance # Minimum probability of individual true negative. num_threads = args.num_threads # Number of threads to use for parallelization. + winner_takes_all = args.winner_takes_all # Enable winner-takes-all k-mer reassignment + batch_size = args.batch_size # Batch size for winner-takes-all processing keep_raw = args.keep_raw # Keep raw results in output file. show_all = args.show_all # Show all organisms (no matter if present) in output file. 
min_coverage_list = args.min_coverage_list # a list of percentages of unique k-mers covered by reads in the sample. @@ -84,6 +101,13 @@ def main(args): outdir = os.path.dirname(out) # path to output directory out_filename = os.path.basename(out) # output filename + # Validate that batch_size is only used with winner_takes_all + if batch_size != 1000 and not winner_takes_all: + raise ValueError( + "--batch_size can only be used with --winner_takes_all. " + "Either remove --batch_size or add --winner_takes_all." + ) + # check if the output filename is valid if os.path.splitext(out_filename)[1] != ".xlsx": raise ValueError( @@ -192,6 +216,8 @@ def main(args): significance, ani_thresh, num_threads, + winner_takes_all, + batch_size, ) # remove unnecessary columns From ef9077d0c21b79f4688f54b180f8590e20dfc236 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Thu, 15 Jan 2026 16:23:55 -0500 Subject: [PATCH 17/41] Small tweaks to help statement for new winner-map related arguments. --- src/yacht/run_YACHT.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/yacht/run_YACHT.py b/src/yacht/run_YACHT.py index ea768295..246d917f 100644 --- a/src/yacht/run_YACHT.py +++ b/src/yacht/run_YACHT.py @@ -48,15 +48,15 @@ def add_arguments(parser): parser.add_argument( "--winner_takes_all", action="store_true", - help="Enable winner-takes-all k-mer reassignment and relative abundance estimation. " - "Shared k-mers are assigned to the organism with highest ANI. " - "Uses memory-efficient batch processing. More accurate but slower.", + help="Enables winner-takes-all k-mer reassignment and relative abundance estimation. " + "Shared k-mers are assigned to the taxon with the highest ANI. " + "Uses more memory-efficient batch processing. ", default=False, ) parser.add_argument( "--batch_size", type=int, - help="Batch size for winner-takes-all processing (lower = less memory, slower). 
" + help="Batch size for winner-takes-all processing (lower size uses less memory)." "Only used with --winner_takes_all. Default: 1000", default=1000, ) From 45faa7408250bf87f9652bfb283efc495f857d20 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Thu, 15 Jan 2026 16:32:27 -0500 Subject: [PATCH 18/41] Fixed a small typo. --- src/yacht/run_YACHT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/yacht/run_YACHT.py b/src/yacht/run_YACHT.py index 246d917f..1bf5c960 100644 --- a/src/yacht/run_YACHT.py +++ b/src/yacht/run_YACHT.py @@ -56,7 +56,7 @@ def add_arguments(parser): parser.add_argument( "--batch_size", type=int, - help="Batch size for winner-takes-all processing (lower size uses less memory)." + help="Batch size for winner-takes-all processing (lower size uses less memory). " "Only used with --winner_takes_all. Default: 1000", default=1000, ) From 6997b5c8b16711dfe5cc6f6f49334b7db0f944d5 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Thu, 15 Jan 2026 17:14:31 -0500 Subject: [PATCH 19/41] Renamed MEDIAN_ANI_THRESHOLD to avoid confusion. 
--- src/yacht/cov_calc.py | 4 ++-- src/yacht/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py index f11e05fb..47238dd9 100644 --- a/src/yacht/cov_calc.py +++ b/src/yacht/cov_calc.py @@ -13,7 +13,7 @@ from yacht.utils import _ContainArgs from yacht.utils import AniResult from yacht.utils import AdjustStatus, AdjustStatusType -from yacht.utils import SAMPLE_SIZE_CUTOFF, PVALUE_CUTOFF, MEDIAN_ANI_THRESHOLD, MAX_MEDIAN_FOR_MEAN_FINAL_EST, MIN_COUNT_THRESH, ksize +from yacht.utils import SAMPLE_SIZE_CUTOFF, PVALUE_CUTOFF, MEDIAN_COV_THRESHOLD, MAX_MEDIAN_FOR_MEAN_FINAL_EST, MIN_COUNT_THRESH, ksize from scipy.stats import poisson, variation from typing import Optional, Tuple, Dict, Any @@ -90,7 +90,7 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma mean_cov = sum(full_covs)//len(full_covs) geq1_mean_cov = sum(full_covs)//len(covs) - if median_cov > MEDIAN_ANI_THRESHOLD: + if median_cov > MEDIAN_COV_THRESHOLD: return_lambda = AdjustStatus.high() else: diff --git a/src/yacht/utils.py b/src/yacht/utils.py index bef73c3f..af378c92 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -32,7 +32,7 @@ SAMPLE_SIZE_CUTOFF: int = 25 PVALUE_CUTOFF: float = 0.9999999999 MIN_ANI_THRESHOLD: float = 0.90 # Minimum ANI threshold for filtering organisms -MEDIAN_ANI_THRESHOLD: float = 2.00 +MEDIAN_COV_THRESHOLD: float = 2.00 MAX_MEDIAN_FOR_MEAN_FINAL_EST: float = 15.0 MIN_COUNT_THRESH: int = 3 LAMBDA_EPSILON: float = 1e-10 # Minimum lambda value to avoid dividing by zero From 3333649565c9b8221d0b4526180e2073a7b60a2f Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Fri, 16 Jan 2026 11:21:33 -0500 Subject: [PATCH 20/41] Fixed critical bug due to umap_unordered- now matching on organism name. 
--- src/yacht/hypothesis_recovery_src.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index 80446a66..0b12d94e 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -136,18 +136,22 @@ def _calculate_coverage_worker(args): Worker function for parallel coverage calculation. Uses global _worker_sample_sig instead of passing it as argument to avoid pickling overhead. - :param args: Tuple of (md5sum, path_to_genome_temp_dir, ksize) - :return: Result from cov_calc or None if error occurs + :param args: Tuple of (md5sum, organism_name, path_to_genome_temp_dir, ksize) + :return: Result DataFrame from cov_calc with organism_name, or None if error occurs """ - md5sum, path_to_genome_temp_dir, ksize = args + md5sum, organism_name, path_to_genome_temp_dir, ksize = args try: sig = load_signature_with_ksize( os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX), ksize, ) - return cov_calc(_worker_sample_sig, sig) + result_df = cov_calc(_worker_sample_sig, sig) + if result_df is not None: + # Add organism_name to the result for proper matching + result_df['organism_name'] = organism_name + return result_df except Exception as e: - logger.warning(f"Error calculating coverage for {md5sum}: {e}") + logger.warning(f"Error calculating coverage for {organism_name} ({md5sum}): {e}") return None @@ -243,17 +247,19 @@ def __find_exclusive_hashes( # Calculate coverage statistics for each organism (parallelized) logger.info(f"Calculating coverage statistics using {num_threads} threads") - # Calculate (organism) chunk size for progress visibility + # Calculates optimal chunk size for progress visibility chunk_size = max(1, len(organism_md5sum_list) // (num_threads * 50)) logger.info(f"Using chunk size of {chunk_size} for parallel processing") with Pool(processes=num_threads, 
initializer=_init_coverage_worker, initargs=(sample_sig,)) as pool: # Prepare arguments for parallel processing (sample_sig shared via initializer to avoid pickling overhead) + # Include organism_name for proper matching (fixes misalignment bug from imap_unordered) + organism_name_list = sub_manifest["organism_name"].to_list() args_list = [ - (md5sum, path_to_genome_temp_dir, ksize) - for md5sum in organism_md5sum_list + (md5sum, organism_name, path_to_genome_temp_dir, ksize) + for md5sum, organism_name in zip(organism_md5sum_list, organism_name_list) ] - # Use imap_unordered for better performance (order doesn't matter for coverage stats) + # Use imap_unordered for better performance (we are matching on organism_name) stats_list = list( tqdm( pool.imap_unordered(_calculate_coverage_worker, args_list, chunksize=chunk_size), @@ -268,11 +274,9 @@ def __find_exclusive_hashes( if not stats_list: raise ValueError("No coverage statistics were successfully calculated") + # Concatenate all results - organism_name is already included in each DataFrame final_stats_df = pd.concat(stats_list, ignore_index=True) - # Add organism_name to final_stats_df for merging (stats are in same order as sub_manifest) - final_stats_df['organism_name'] = sub_manifest['organism_name'].values - del stats_list # free up memory # Find hashes that are unique to each organism and in the sample From 782b2fdf5248b8c17440de6b6924bab1bb75b05a Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Tue, 27 Jan 2026 17:13:36 -0500 Subject: [PATCH 21/41] Added ANI-capping utility to avoid biologically impossible ANI inflation due to low lambda. 
--- src/yacht/utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/yacht/utils.py b/src/yacht/utils.py index af378c92..0823c35c 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -32,7 +32,7 @@ SAMPLE_SIZE_CUTOFF: int = 25 PVALUE_CUTOFF: float = 0.9999999999 MIN_ANI_THRESHOLD: float = 0.90 # Minimum ANI threshold for filtering organisms -MEDIAN_COV_THRESHOLD: float = 2.00 +MEDIAN_ANI_THRESHOLD: float = 2.00 MAX_MEDIAN_FOR_MEAN_FINAL_EST: float = 15.0 MIN_COUNT_THRESH: int = 3 LAMBDA_EPSILON: float = 1e-10 # Minimum lambda value to avoid dividing by zero @@ -872,14 +872,20 @@ def ani_from_lambda(lambda_val, lam_mean, k_value, full_cov): return None adj_index = contain_count / (1.0 - math.exp(-lambda_val)) / len(full_cov) - + # Calculating ani using math.pow for clarity (and to maintain type correctness) ani = math.pow(adj_index, 1.0 / k_value) - + + # Caps ANI at 1.0 (100% identity, so impossible to exceed this) + # This can happen in instances with very low lambda values where the adjustment inflates ANI + if ani > 1.0: + logger.debug(f"ANI {ani:.6f} exceeds 1.0 (lambda={lambda_val:.4f}, adj_index={adj_index:.4f}), capping at 1.0") + ani = 1.0 + if ani < 0.0 or math.isnan(ani): ret_ani = None - + else: ret_ani = ani - + return ret_ani From 835f43460847abf28454cabec70c9f9ceca1903a Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Wed, 28 Jan 2026 10:05:33 -0500 Subject: [PATCH 22/41] Updated cov_calc to cap ANIs at 1.0; pt II of ani capping. 
--- src/yacht/cov_calc.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py index 47238dd9..a52df361 100644 --- a/src/yacht/cov_calc.py +++ b/src/yacht/cov_calc.py @@ -13,7 +13,7 @@ from yacht.utils import _ContainArgs from yacht.utils import AniResult from yacht.utils import AdjustStatus, AdjustStatusType -from yacht.utils import SAMPLE_SIZE_CUTOFF, PVALUE_CUTOFF, MEDIAN_COV_THRESHOLD, MAX_MEDIAN_FOR_MEAN_FINAL_EST, MIN_COUNT_THRESH, ksize +from yacht.utils import SAMPLE_SIZE_CUTOFF, PVALUE_CUTOFF, MEDIAN_ANI_THRESHOLD, MAX_MEDIAN_FOR_MEAN_FINAL_EST, MIN_COUNT_THRESH, ksize from scipy.stats import poisson, variation from typing import Optional, Tuple, Dict, Any @@ -48,10 +48,15 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma if len(covs)==0: return None - + naive_ani = math.pow(contain_count/len(gn_kmers_items), 1/ksize) + # Caps naive_ani at 1.0 to prevent biologically impossible ANIs + if naive_ani > 1.0: + logger.debug(f"Naive ANI {naive_ani:.6f} exceeds 1.0, capping at 1.0") + naive_ani = 1.0 + covs.sort() if len(covs) == 0: @@ -90,7 +95,7 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma mean_cov = sum(full_covs)//len(full_covs) geq1_mean_cov = sum(full_covs)//len(covs) - if median_cov > MEDIAN_COV_THRESHOLD: + if median_cov > MEDIAN_ANI_THRESHOLD: return_lambda = AdjustStatus.high() else: From 8b36fb1005a5a4b790205714a9db54f394d1a840 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Thu, 29 Jan 2026 14:54:06 -0500 Subject: [PATCH 23/41] Removed old commented code that was replaced by the winner-map addition. 
--- src/yacht/cov_calc.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py index a52df361..8f960e70 100644 --- a/src/yacht/cov_calc.py +++ b/src/yacht/cov_calc.py @@ -146,33 +146,6 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma else: final_est_ani = opt_est_ani -# This is the "winner_map" section. I'm leaving it out of the codebase for now in case we would like to revisit this - -# Calculate min_ani using a conditional expression (Python's 'if/else if/else') - #if args.minimum_ani is not None: - # min_ani = args.minimum_ani / 100.0 - #elif args.pseudotax: - # min_ani = MIN_ANI_P_DEF - #else: - # min_ani = MIN_ANI_DEF - - # Check the final estimated ANI against the calculated minimum - #if final_est_ani < min_ani: - # Use 'is not None' to check for optional values (like Rust's is_some()) - # if winner_map is not None: - # Check if we should log the reassignment event - # if log_reassign: - # logging.info( - # "Genome/contig %s/%s has ANI = %.2f < %.2f after reassigning %d k-mers (%d contained k-mers after reassign)", - # genome_sketch.file_name, - # genome_sketch.first_contig_name, - # final_est_ani * 100.0, - # min_ani * 100.0, - # kmers_lost_count, - # contain_count - # ) - -######## End winner_map section low_ani, high_ani, low_lambda, high_lambda= None, None, None, None From 7e9f9f9ab346a112d0b07b9c7ef59a3c8022e9e6 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Thu, 29 Jan 2026 15:06:13 -0500 Subject: [PATCH 24/41] Updated ANI adjustment to avoid ANI adjustment under high-coverage, update to median_ani_threshold. 
--- src/yacht/cov_calc.py | 5 +++-- src/yacht/utils.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py index 8f960e70..3bbbe441 100644 --- a/src/yacht/cov_calc.py +++ b/src/yacht/cov_calc.py @@ -141,12 +141,13 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma opt_est_ani = ani_from_lambda(opt_lambda, mean_cov, 31, full_covs) - if opt_lambda == None or opt_est_ani == None or no_adj == True: + if opt_lambda == None or opt_est_ani == None or no_adj == True or return_lambda.status == AdjustStatusType.HIGH: + # Avoids adjusting for high-coverage (i.e. where median > MEDIAN_ANI_THRESHOLD) final_est_ani = naive_ani + logger.debug(f"Using naive ANI (no adjustment needed): median_cov={median_cov}, naive_ani={naive_ani:.4f}") else: final_est_ani = opt_est_ani - low_ani, high_ani, low_lambda, high_lambda= None, None, None, None #Conditional calculation of confidence intervals diff --git a/src/yacht/utils.py b/src/yacht/utils.py index 0823c35c..450df3bf 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -32,7 +32,7 @@ SAMPLE_SIZE_CUTOFF: int = 25 PVALUE_CUTOFF: float = 0.9999999999 MIN_ANI_THRESHOLD: float = 0.90 # Minimum ANI threshold for filtering organisms -MEDIAN_ANI_THRESHOLD: float = 2.00 +MEDIAN_ANI_THRESHOLD: float = 3.00 MAX_MEDIAN_FOR_MEAN_FINAL_EST: float = 15.0 MIN_COUNT_THRESH: int = 3 LAMBDA_EPSILON: float = 1e-10 # Minimum lambda value to avoid dividing by zero From d35725dd4e7d37d9529b30cadae7355a18b88065 Mon Sep 17 00:00:00 2001 From: "R. 
Taylor Raborn" Date: Thu, 12 Mar 2026 17:13:31 -0400 Subject: [PATCH 25/41] Patched a minor bug relating to performance of python's gzip on Mac OS, replacing it with the system gzip --- src/yacht/utils.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/yacht/utils.py b/src/yacht/utils.py index 450df3bf..9abb2a1b 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -239,9 +239,10 @@ def run_yacht_train_core( selected_sig_files = pd.read_csv(os.path.join(path_to_temp_dir, 'selected_result.tsv'), sep="\t", header=None) selected_sig_files = selected_sig_files[0].to_list() - # get the mapping from signature file name to genome name - mapping = {sig_info_dict[name][-1]:name for name in sig_info_dict} - selected_genome_names_set = set([mapping[sig_file_path] for sig_file_path in selected_sig_files]) + # get the mapping from signature file name to genome name; normalize to basename for matching. Basename extracted from C++ + mapping = {os.path.basename(sig_info_dict[name][-1]): name for name in sig_info_dict} + selected_genome_names_set = set([mapping[os.path.basename(sig_file_path)] for sig_file_path in selected_sig_files]) + # remove the close related organisms from the reference genome list manifest_df = [] @@ -302,7 +303,13 @@ def collect_signature_info( ], ) - return {sig[1]: (sig[2], sig[3], sig[4], sig[5], sig[0]) for sig in tqdm(signatures) if sig} + def get_key_with_warning(sig): + if not sig[1]: + logger.warning(f"Signature has no name, using md5sum as identifier: {sig[2]}") + return sig[2] + return sig[1] + + return {get_key_with_warning(sig): (sig[2], sig[3], sig[4], sig[5], sig[0]) for sig in tqdm(signatures) if sig} class Prediction: @@ -562,23 +569,28 @@ def check_download_args(args, db_type): logger.error("We now haven't supported for virus database.") sys.exit(1) - def _decompress_and_remove(file_path: str) -> None: """ Decompresses a GZIP-compressed file and removes the original compressed 
file. :param file_path: The path to the .sig.gz file that needs to be decompressed and deleted. :return: None """ + import subprocess try: output_filename = os.path.splitext(file_path)[0] - with gzip.open(file_path, 'rb') as f_in: - with open(output_filename, 'wb') as f_out: - f_out.write(f_in.read()) - - os.remove(file_path) - + with open(output_filename, 'wb') as f_out: + result = subprocess.run( + ['gunzip', '-c', file_path], + stdout=f_out, + stderr=subprocess.PIPE + ) + if result.returncode == 0: + os.remove(file_path) + else: + logger.info(f"gunzip failed for {file_path}: {result.stderr.decode()}") except Exception as e: logger.info(f"Failed to process {file_path}: {e}") + def decompress_all_sig_files(sig_files: List[str], num_threads: int) -> None: """ From 7f04e173262a7c7e31420369378eac3747a80bc3 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Fri, 13 Mar 2026 14:10:48 -0400 Subject: [PATCH 26/41] Added two-pass mode to --winner-take-all procedure. --- src/yacht/hypothesis_recovery_src.py | 244 +++++++++++++++++++++++++-- 1 file changed, 231 insertions(+), 13 deletions(-) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index 0b12d94e..3a7da990 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -11,7 +11,15 @@ import sourmash import glob from typing import List, Set, Tuple, Dict -from .utils import load_signature_with_ksize, decompress_all_sig_files, MIN_ANI_THRESHOLD +from .utils import ( + load_signature_with_ksize, + decompress_all_sig_files, + MIN_ANI_THRESHOLD, + ratio_lambda, + ani_from_lambda, + MIN_COUNT_THRESH, + SAMPLE_SIZE_CUTOFF, +) # Configure Loguru logger from loguru import logger from .cov_calc import cov_calc @@ -164,6 +172,7 @@ def get_exclusive_hashes( num_threads: int = 16, winner_takes_all: bool = False, batch_size: int = 1000, + two_pass: bool = True, ) -> Tuple[List[Tuple[int, int]], pd.DataFrame, pd.DataFrame]: """ This function gets the 
unique hashes exclusive to each of the organisms that have non-zero overlap with the sample, and @@ -185,6 +194,8 @@ def get_exclusive_hashes( :param num_threads: int (number of threads to use for parallel coverage calculation, default: 16) :param winner_takes_all: bool (enable winner-takes-all k-mer reassignment and relative abundance, default: False) :param batch_size: int (batch size for winner-takes-all processing, default: 1000) + :param two_pass: bool (use sylph's two-pass approach for more accurate k-mer reassignment, default: True) + Set to False for original one-pass behavior (for testing/comparison) :return: a list of tuples, each tuple contains the following information: 1. the number of unique hashes exclusive to the organism under consideration @@ -291,8 +302,31 @@ def __find_exclusive_hashes( # Conditionally run winner-takes-all (memory-intensive but provides relative abundance) if winner_takes_all: - logger.info("Building winner map for k-mer reassignment (winner-takes-all strategy)") - winner_map = build_winner_map(final_stats_df, path_to_genome_temp_dir, ksize, batch_size) + if two_pass: + # Two-pass approach (sylph-aligned): more accurate for closely related organisms + # Pass 1: Build initial winner map using original ANI estimates + logger.info("Pass 1: Building initial winner map for k-mer reassignment") + winner_map = build_winner_map(final_stats_df, path_to_genome_temp_dir, ksize, batch_size) + + # Recalculate ANI using only won k-mers (prepares for Pass 2) + final_stats_df = recalculate_ani_from_winner_map( + final_stats_df, winner_map, sample_sig, ksize, batch_size + ) + + # Pass 2: Rebuild winner map with refined ANI estimates + # Only include organisms that weren't eliminated in the recalculation + logger.info("Pass 2: Rebuilding winner map with refined ANI estimates") + winner_map = build_winner_map(final_stats_df, path_to_genome_temp_dir, ksize, batch_size) + else: + # One-pass approach (original): faster but less accurate for related 
organisms + logger.info("One-pass mode: Building winner map with initial ANI estimates") + winner_map = build_winner_map(final_stats_df, path_to_genome_temp_dir, ksize, batch_size) + + # Add placeholder columns for consistency + final_stats_df['reassignment_status'] = 'one_pass' + final_stats_df['original_ani'] = final_stats_df['final_est_ani'].copy() + + # Calculate relative abundance using final winner map final_stats_df = estimate_relative_abundance(final_stats_df, winner_map, sample_sig, batch_size) # Free up memory by dropping genome_sketch column (no longer needed) @@ -301,9 +335,11 @@ def __find_exclusive_hashes( final_stats_df.drop(columns=['genome_sketch'], inplace=True) else: logger.info("Skipping winner-takes-all (not enabled). Use --winner_takes_all to enable relative abundance estimation.") - # Add placeholder columns for consistency + # Add placeholder columns for consistency with winner-takes-all output final_stats_df['rel_abund'] = float('nan') final_stats_df['kmers_lost'] = 0 + final_stats_df['reassignment_status'] = 'not_applicable' + final_stats_df['original_ani'] = final_stats_df['final_est_ani'].copy() # Also drop genome_sketch to save memory if 'genome_sketch' in final_stats_df.columns: @@ -371,6 +407,109 @@ def build_winner_map( return winner_map +def recalculate_ani_from_winner_map( + final_stats_df: pd.DataFrame, + winner_map: Dict[int, Tuple[float, str]], + sample_sig: sourmash.SourmashSignature, + ksize: int, + batch_size: int = 1000 +) -> pd.DataFrame: + """ + Recalculates ANI for each organism using only k-mers it 'won' in the winner map. + This implements Pass 2 of sylph's two-pass winner-takes-all approach. + + Organisms that lost all their k-mers are marked as 'eliminated' with rel_abund=0. + This is Option D handling: keep in results but flag as inconclusive. 
+ + :param final_stats_df: DataFrame with coverage statistics including organism_name, + final_est_ani, and genome_sketch columns + :param winner_map: k-mer to (ANI, organism_name) mapping from build_winner_map() + :param sample_sig: Sample signature with k-mer abundances + :param ksize: k-mer size for ANI calculation + :param batch_size: Number of organisms to process per batch (default: 1000) + :return: Updated DataFrame with recalculated ANI values and reassignment_status column + """ + logger.info("Recalculating ANI using only won k-mers (Pass 2 of two-pass approach)") + + sample_hashes = sample_sig.minhash.hashes + total_organisms = len(final_stats_df) + + # Initialize reassignment_status column + final_stats_df['reassignment_status'] = 'active' + + # Store original ANI for reference + final_stats_df['original_ani'] = final_stats_df['final_est_ani'].copy() + + eliminated_count = 0 + + for batch_start in range(0, total_organisms, batch_size): + batch_end = min(batch_start + batch_size, total_organisms) + batch_num = batch_start // batch_size + 1 + total_batches = (total_organisms + batch_size - 1) // batch_size + + for idx in tqdm( + range(batch_start, batch_end), + desc=f"Recalculating ANI (batch {batch_num}/{total_batches})", + total=batch_end - batch_start + ): + row = final_stats_df.iloc[idx] + organism_name = row['organism_name'] + genome_sig = row['genome_sketch'] + + # Count k-mers won by this organism and build coverage list + won_kmers_in_sample = [] + total_won_kmers = 0 + + for kmer in genome_sig.minhash.hashes.keys(): + if kmer in winner_map and winner_map[kmer][1] == organism_name: + total_won_kmers += 1 + if kmer in sample_hashes and sample_hashes[kmer] > 0: + won_kmers_in_sample.append(sample_hashes[kmer]) + + # Handle organisms that lost all k-mers (Option D: mark as eliminated) + if total_won_kmers == 0: + final_stats_df.at[idx, 'reassignment_status'] = 'eliminated' + final_stats_df.at[idx, 'final_est_ani'] = float('nan') + eliminated_count += 
1 + continue + + # Check if we have enough data for lambda estimation + if len(won_kmers_in_sample) < SAMPLE_SIZE_CUTOFF: + # Not enough won k-mers in sample - mark as eliminated + final_stats_df.at[idx, 'reassignment_status'] = 'eliminated' + final_stats_df.at[idx, 'final_est_ani'] = float('nan') + eliminated_count += 1 + continue + + # Build full_cov array (zeros for won k-mers not in sample + coverages for those in sample) + num_zeros = total_won_kmers - len(won_kmers_in_sample) + full_cov = [0] * num_zeros + won_kmers_in_sample + + # Recalculate lambda using ratio method + new_lambda = ratio_lambda(full_cov, MIN_COUNT_THRESH) + + if new_lambda is None: + # Lambda estimation failed - keep original ANI but mark status + final_stats_df.at[idx, 'reassignment_status'] = 'lambda_failed' + # Keep original ANI value (don't modify final_est_ani) + continue + + # Recalculate ANI from new lambda + mean_cov = sum(full_cov) / len(full_cov) if full_cov else 0 + new_ani = ani_from_lambda(new_lambda, mean_cov, ksize, full_cov) + + if new_ani is not None: + final_stats_df.at[idx, 'final_est_ani'] = new_ani + else: + # ANI calculation failed - mark status but keep original + final_stats_df.at[idx, 'reassignment_status'] = 'ani_failed' + + logger.info(f"ANI recalculation complete: {eliminated_count} organisms eliminated, " + f"{total_organisms - eliminated_count} remain active") + + return final_stats_df + + def estimate_relative_abundance( final_stats_df: pd.DataFrame, winner_map: Dict[int, Tuple[float, str]], @@ -386,6 +525,8 @@ def estimate_relative_abundance( 2. Total coverage from k-mers "won" by each organism (used for relative abundance) 3. Relative abundance normalized across all organisms + Organisms marked as 'eliminated' in reassignment_status are skipped (rel_abund = 0). 
+ :param final_stats_df: DataFrame with coverage statistics :param winner_map: k-mer to (ANI, organism_name) mapping from build_winner_map() :param sample_sig: Sample signature with k-mer abundances @@ -401,6 +542,9 @@ def estimate_relative_abundance( sample_hashes = sample_sig.minhash.hashes total_organisms = len(final_stats_df) + # Check if reassignment_status column exists (from two-pass approach) + has_reassignment_status = 'reassignment_status' in final_stats_df.columns + # Process in batches to reduce memory usage for batch_start in range(0, total_organisms, batch_size): batch_end = min(batch_start + batch_size, total_organisms) @@ -414,6 +558,11 @@ def estimate_relative_abundance( ): row = final_stats_df.iloc[idx] organism_name = row['organism_name'] + + # Skip eliminated organisms (Option D: they keep rel_abund = 0) + if has_reassignment_status and row['reassignment_status'] == 'eliminated': + continue + genome_sig = row['genome_sketch'] kmers_lost_count = 0 @@ -434,7 +583,11 @@ def estimate_relative_abundance( total_coverage += sample_hashes[kmer] final_stats_df.at[idx, 'kmers_lost'] = kmers_lost_count - final_stats_df.at[idx, 'rel_abund'] = total_coverage + + # Normalizes coverage by genome size to avoid bias toward larger genomes + # (genome size estimated as num_kmers * the scale factor from sourmash) + genome_size = len(genome_sig.minhash.hashes) * genome_sig.minhash.scaled + final_stats_df.at[idx, 'rel_abund'] = total_coverage / genome_size if genome_size > 0 else 0.0 # Normalize relative abundance to sum to 1.0 across all organisms total_abundance = final_stats_df['rel_abund'].sum() @@ -559,6 +712,8 @@ def hypothesis_recovery( num_threads: int = 16, winner_takes_all: bool = False, batch_size: int = 1000, + two_pass: bool = True, + calculate_coverage: bool = False, ): """ Go through each of the organisms that have non-zero overlap with the sample and perform a hypothesis test for the @@ -582,6 +737,8 @@ def hypothesis_recovery( :param significance: 
significance level for the hypothesis test :param ani_thresh: threshold for ANI (i.e. how similar do the genomes need to be in order to be considered the same) :param num_threads: number of threads to use for parallelization + :param two_pass: bool (use sylph's two-pass approach for winner-takes-all, default: True) + :param calculate_coverage: bool (use calculated coverage per organism instead of min_coverage_list, default: False) :return: a list of pandas dataframe with the results of the hypothesis tests based on different min_coverage values """ @@ -614,7 +771,7 @@ def hypothesis_recovery( # Get the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample exclusive_hashes_info, manifest, final_stats_df = get_exclusive_hashes( manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir, - num_threads, winner_takes_all, batch_size + num_threads, winner_takes_all, batch_size, two_pass ) # Set up the results dataframe columns @@ -637,8 +794,44 @@ def hypothesis_recovery( # Using multiprocessing.Pool to parallelize the execution manifest_list = [] - for min_coverage in tqdm(min_coverage_list, desc="Computing hypothesis recovery"): - logger.info(f"Computing hypothesis recovery for min_coverage={min_coverage}") + + if calculate_coverage: + # CALCULATE_COVERAGE MODE: Use calculated coverage (final_est_cov) per organism + logger.info("Using calculate_coverage mode: applying calculated coverage per organism") + + # Build a mapping from organism_name to calculated coverage (final_est_cov) + # Coverage depth (lambda) is converted to detection fraction using Poisson probability: + # P(a k-mer is detected) = 1 - exp(-lambda) + # This gives the expected fraction of k-mers that will be observed at least once. 
+ coverage_map = {} + for _, row in final_stats_df.iterrows(): + org_name = row['organism_name'] + cov_val = row['final_est_cov'] + median_cov = row['median_cov'] + + if pd.notna(cov_val) and cov_val > 0: + # Primary choicev: use final_est_cov (lambda) with Poisson detection probability + # Convert depth to the expected fraction of k-mers detected + coverage_map[org_name] = 1.0 - np.exp(-cov_val) + elif pd.notna(median_cov) and median_cov > 0: + # Fallback: use median_cov when lambda estimation failed + # This helps detect low-abundance taxa where lambda couldn't be estimated + # (e.g., fewer than 25 non-zero coverage values) + coverage_map[org_name] = 1.0 - np.exp(-median_cov) + logger.warning(f"No valid lambda for {org_name}, using median_cov={median_cov:.3f} " + f"(detection fraction: {coverage_map[org_name]:.3f})") + else: + # Last resort: if neither is available, use strictest test + coverage_map[org_name] = 1.0 + logger.warning(f"No valid coverage data for {org_name}, using default 1.0") + + # Get organism names in manifest order (aligned with exclusive_hashes_info) + organism_names = manifest["organism_name"].to_list() + + # Build per-organism coverage list aligned with exclusive_hashes_info + per_organism_coverage = [coverage_map.get(name, 1.0) for name in organism_names] + + # Run hypothesis test with per-organism coverage with Pool(processes=num_threads) as p: params = ( ( @@ -646,20 +839,45 @@ def hypothesis_recovery( ksize, significance, ani_thresh, - min_coverage, + per_organism_coverage[i], # Per-organism coverage ) for i in range(len(exclusive_hashes_info)) ) results = p.starmap(single_hyp_test, params) - logger.info(f"Finished computing all results for min_coverage value: {min_coverage}") + logger.info("Finished computing hypothesis recovery with calculate_coverage") - # Create a pandas dataframe to store the results + # Create results DataFrame results = pd.DataFrame(results, columns=given_columns) - # combine the results with the manifest - 
manifest["min_coverage"] = min_coverage + # Add per-organism coverage to manifest (replaces the fixed min_coverage column) + manifest["min_coverage"] = per_organism_coverage manifest_list.append(pd.concat([manifest, results], axis=1)) + else: + # ORIGINAL MODE: Loop over user-supplied min_coverage_list + for min_coverage in tqdm(min_coverage_list, desc="Computing hypothesis recovery"): + logger.info(f"Computing hypothesis recovery for min_coverage={min_coverage}") + with Pool(processes=num_threads) as p: + params = ( + ( + exclusive_hashes_info[i], + ksize, + significance, + ani_thresh, + min_coverage, + ) + for i in range(len(exclusive_hashes_info)) + ) + results = p.starmap(single_hyp_test, params) + logger.info(f"Finished computing all results for min_coverage value: {min_coverage}") + + # Create a pandas dataframe to store the results + results = pd.DataFrame(results, columns=given_columns) + + # combine the results with the manifest + manifest["min_coverage"] = min_coverage + manifest_list.append(pd.concat([manifest, results], axis=1)) + # Merge coverage statistics into each manifest DataFrame # Select key coverage columns to include in output (including winner_map results) coverage_cols = [ From 832faa9b1934a55a5e35203d3b1114f1170b754c Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Fri, 13 Mar 2026 14:14:22 -0400 Subject: [PATCH 27/41] Various updates to run_YACHT, including adding yacht-level arguments --calculate_coverage and --no_two_pass. --- src/yacht/run_YACHT.py | 93 ++++++++++++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 25 deletions(-) diff --git a/src/yacht/run_YACHT.py b/src/yacht/run_YACHT.py index 1bf5c960..56c0a516 100644 --- a/src/yacht/run_YACHT.py +++ b/src/yacht/run_YACHT.py @@ -60,6 +60,14 @@ def add_arguments(parser): "Only used with --winner_takes_all. Default: 1000", default=1000, ) + parser.add_argument( + "--no_two_pass", + action="store_true", + help="Disable sylph's two-pass winner-takes-all approach.
Uses original one-pass method. " + "Two-pass (default) is more accurate for closely related organisms but slower. " + "Only applies when --winner_takes_all is enabled.", + default=False, + ) parser.add_argument( "--keep_raw", action="store_true", help="Keep raw results in output file." ) @@ -74,9 +82,19 @@ def add_arguments(parser): type=float, help="A list of percentages of unique k-mers covered by reads in the sample. " "Each value should be between 0 and 1, with 0 being the most sensitive (and least " - "precise) and 1 being the most precise (and least sensitive).", + "precise) and 1 being the most precise (and least sensitive). " + "Default: [1, 0.5, 0.1, 0.05, 0.01]", required=False, - default=[1, 0.5, 0.1, 0.05, 0.01], + default=None, + ) + parser.add_argument( + "--calculate_coverage", + action="store_true", + help="Automatically calculate coverage for each organism using the sylph coverage model " + "(lambda estimation) instead of using user-supplied min_coverage_list values. " + "When enabled, produces a single output sheet with per-organism calculated coverage. " + "Cannot be used together with --min_coverage_list.", + default=False, ) parser.add_argument( "--out", @@ -94,10 +112,25 @@ def main(args): num_threads = args.num_threads # Number of threads to use for parallelization. winner_takes_all = args.winner_takes_all # Enable winner-takes-all k-mer reassignment batch_size = args.batch_size # Batch size for winner-takes-all processing + two_pass = not args.no_two_pass # Use sylph's two-pass approach (default: True) keep_raw = args.keep_raw # Keep raw results in output file. show_all = args.show_all # Show all organisms (no matter if present) in output file. - min_coverage_list = args.min_coverage_list # a list of percentages of unique k-mers covered by reads in the sample. 
+ calculate_coverage = args.calculate_coverage # Use calculated coverage instead of user-supplied list out = str(Path(args.out).absolute()) # full path to output excel file + + # Validate mutual exclusivity of --calculate_coverage and --min_coverage_list + if calculate_coverage and args.min_coverage_list is not None: + raise ValueError( + "--calculate_coverage and --min_coverage_list cannot be used together. " + "Use --calculate_coverage for automatic per-organism coverage calculation, " + "or --min_coverage_list for manual coverage thresholds." + ) + + # Set default min_coverage_list if not provided and not using calculate_coverage + if args.min_coverage_list is None: + min_coverage_list = [1, 0.5, 0.1, 0.05, 0.01] # Default values + else: + min_coverage_list = args.min_coverage_list outdir = os.path.dirname(out) # path to output directory out_filename = os.path.basename(out) # output filename @@ -218,6 +251,8 @@ def main(args): num_threads, winner_takes_all, batch_size, + two_pass, + calculate_coverage, ) # remove unnecessary columns @@ -237,29 +272,37 @@ def main(args): logger.info(f"Saving results to {outdir}.") # save the results with different min_coverage with pd.ExcelWriter(out, engine="openpyxl", mode="w") as writer: - # save the raw results (i.e., min_coverage=1.0) - if keep_raw: - temp_mainifest = manifest_list[0].copy() - temp_mainifest.rename( - columns={ - "acceptance_threshold_with_coverage": "acceptance_threshold_wo_coverage", - "actual_confidence_with_coverage": "actual_confidence_wo_coverage", - "alt_confidence_mut_rate_with_coverage": "alt_confidence_mut_rate_wo_coverage", - }, - inplace=True, - ) - temp_mainifest.to_excel(writer, sheet_name="raw_result", index=False) - # save the results with different min_coverage given by the user - if not has_raw: - min_coverage_list = min_coverage_list[1:] - manifest_list = manifest_list[1:] - - for min_coverage, temp_mainifest in zip(min_coverage_list, manifest_list): + if calculate_coverage: + # Calculate 
coverage mode: single sheet with per-organism calculated coverage + temp_manifest = manifest_list[0].copy() if not show_all: - temp_mainifest = temp_mainifest[temp_mainifest["in_sample_est"] == True] - temp_mainifest.to_excel( - writer, sheet_name=f"min_coverage{min_coverage}", index=False - ) + temp_manifest = temp_manifest[temp_manifest["in_sample_est"] == True] + temp_manifest.to_excel(writer, sheet_name="calculated_coverage", index=False) + else: + # Original behavior: multiple sheets based on min_coverage_list + # save the raw results (i.e., min_coverage=1.0) + if keep_raw: + temp_mainifest = manifest_list[0].copy() + temp_mainifest.rename( + columns={ + "acceptance_threshold_with_coverage": "acceptance_threshold_wo_coverage", + "actual_confidence_with_coverage": "actual_confidence_wo_coverage", + "alt_confidence_mut_rate_with_coverage": "alt_confidence_mut_rate_wo_coverage", + }, + inplace=True, + ) + temp_mainifest.to_excel(writer, sheet_name="raw_result", index=False) + # save the results with different min_coverage given by the user + if not has_raw: + min_coverage_list = min_coverage_list[1:] + manifest_list = manifest_list[1:] + + for min_coverage, temp_mainifest in zip(min_coverage_list, manifest_list): + if not show_all: + temp_mainifest = temp_mainifest[temp_mainifest["in_sample_est"] == True] + temp_mainifest.to_excel( + writer, sheet_name=f"min_coverage{min_coverage}", index=False + ) if __name__ == "__main__": From 8fcc6379eab53cf5bc5d66f2d4cfbba2cdf6496f Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Tue, 7 Apr 2026 17:07:10 -0400 Subject: [PATCH 28/41] Removed excessive comments in the threshold filtering step. 
--- src/yacht/hypothesis_recovery_src.py | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index 3a7da990..ad7c5b4e 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -903,29 +903,7 @@ def hypothesis_recovery( how='left' # Keep all organisms, even those without coverage stats ) - # ============================================================================ - # ANI Threshold Filtering - # ============================================================================ - # - # After winner_map k-mer reassignment, filter organisms with low ANI. - # - # Why filter? - # - Removes poor matches (distant relatives, contamination, low complexity) - # - Improves result quality by eliminating noise - # - Standard practice in metagenomic profiling - # - # Threshold: MIN_ANI_THRESHOLD = 0.90 (90% ANI) - # - Matches sylph's MIN_ANI_DEF default - # - 90% ANI commonly used for genus-level distinction - # - Well-supported by microbial genomics literature - # - # Implementation: - # Keep organisms with final_est_ani >= 0.90 OR NaN ANI - # (NaN kept because hypothesis test may still be valid via exclusive k-mers) - # - # Customization: - # To change threshold, modify MIN_ANI_THRESHOLD in utils.py - # + # ANI threshold filtering logger.info(f"Filtering organisms with final_est_ani < {MIN_ANI_THRESHOLD} ({MIN_ANI_THRESHOLD*100:.0f}% ANI)") for i in range(len(manifest_list)): initial_count = len(manifest_list[i]) From 7e89deaf1467acf9e4274ef4f6669fc6ec1342de Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Tue, 7 Apr 2026 17:14:45 -0400 Subject: [PATCH 29/41] Bookmarked code to fix for tomorrow. 
--- src/yacht/hypothesis_recovery_src.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index ad7c5b4e..720abc3f 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -904,6 +904,7 @@ def hypothesis_recovery( ) # ANI threshold filtering + # start here 4/8 logger.info(f"Filtering organisms with final_est_ani < {MIN_ANI_THRESHOLD} ({MIN_ANI_THRESHOLD*100:.0f}% ANI)") for i in range(len(manifest_list)): initial_count = len(manifest_list[i]) From 58e17fe199a3d61dc2bf7cba832df272ed0c8899 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Thu, 9 Apr 2026 12:56:07 -0400 Subject: [PATCH 30/41] Updates for code review. --- src/yacht/hypothesis_recovery_src.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index 720abc3f..e6b10f03 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -685,9 +685,7 @@ def single_hyp_test( in_sample_est = (num_matches >= acceptance_threshold_with_coverage) and ( num_matches != 0 ) - # return in_sample_est, p_val, num_exclusive_kmers, num_exclusive_kmers_coverage, num_matches, \ - # acceptance_threshold_wo_coverage, acceptance_threshold_with_coverage, actual_confidence_wo_coverage, \ - # actual_confidence_with_coverage, alt_confidence_mut_rate, alt_confidence_mut_rate_with_coverage + return ( in_sample_est, p_val, @@ -904,7 +902,6 @@ def hypothesis_recovery( ) # ANI threshold filtering - # start here 4/8 logger.info(f"Filtering organisms with final_est_ani < {MIN_ANI_THRESHOLD} ({MIN_ANI_THRESHOLD*100:.0f}% ANI)") for i in range(len(manifest_list)): initial_count = len(manifest_list[i]) @@ -916,5 +913,21 @@ def hypothesis_recovery( filtered_count = initial_count - len(manifest_list[i]) if filtered_count > 0: logger.info(f" Filtered {filtered_count} organisms 
below ANI threshold from min_coverage={manifest_list[i]['min_coverage'].iloc[0] if len(manifest_list[i]) > 0 else 'N/A'} results") - - return manifest_list + post_filtered_df = manifest_list[0] + total_abundance = post_filtered_df['rel_abund'].sum() + print(f"Total abundance") + print(total_abundance) + print(f"Relative abundance") + print(post_filtered_df['rel_abund']) + #post_filtered_df['rel_abund'] = post_filtered_df['rel_abund'] / total_abundance + for i in manifest_list: + manifest_list[i].loc[:, 'rel_abund'] = manifest_list[i]['rel_abund'] / total_abundance + test = post_filtered_df['rel_abund'] / total_abundance + print(f"Test vector") + print(test) + print(f"Full df post-filtered") + print(post_filtered_df) + logger.info(f"Relative abundance normalized (total coverage: {total_abundance:.2f})") + + + return manifest_list \ No newline at end of file From 5aad63011bfc068562ea01a2f867edc5dad879ec Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Thu, 9 Apr 2026 17:23:46 -0400 Subject: [PATCH 31/41] Draft re-normalization procedure. 
--- src/yacht/hypothesis_recovery_src.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index e6b10f03..64d70199 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -913,21 +913,15 @@ def hypothesis_recovery( filtered_count = initial_count - len(manifest_list[i]) if filtered_count > 0: logger.info(f" Filtered {filtered_count} organisms below ANI threshold from min_coverage={manifest_list[i]['min_coverage'].iloc[0] if len(manifest_list[i]) > 0 else 'N/A'} results") - post_filtered_df = manifest_list[0] - total_abundance = post_filtered_df['rel_abund'].sum() - print(f"Total abundance") - print(total_abundance) - print(f"Relative abundance") - print(post_filtered_df['rel_abund']) #post_filtered_df['rel_abund'] = post_filtered_df['rel_abund'] / total_abundance - for i in manifest_list: - manifest_list[i].loc[:, 'rel_abund'] = manifest_list[i]['rel_abund'] / total_abundance - test = post_filtered_df['rel_abund'] / total_abundance - print(f"Test vector") - print(test) - print(f"Full df post-filtered") - print(post_filtered_df) - logger.info(f"Relative abundance normalized (total coverage: {total_abundance:.2f})") - - + # Re-normalizing, regardless of filter results (i.e. filtered_count) + post_filtered_df = manifest_list[i] + total_abundance = post_filtered_df['rel_abund'].sum() + print(f"Total abundance") + print(total_abundance) + if total_abundance > 0: + manifest_list[i].loc[:, 'rel_abund'] = manifest_list[i]['rel_abund'] / total_abundance + logger.info(f"Relative abundance normalized (total coverage: {total_abundance:.2f}.)") + else: + logger.warning(f"No relative abundance was done after ANI filtering.") return manifest_list \ No newline at end of file From e7f06ddf224f0c63fda6bd934b693dd3e807aa8d Mon Sep 17 00:00:00 2001 From: "R. 
Taylor Raborn" Date: Fri, 10 Apr 2026 12:47:07 -0400 Subject: [PATCH 32/41] Completed re-normalization procedure for rel_abund column in output. --- src/yacht/hypothesis_recovery_src.py | 4 +--- src/yacht/run_YACHT.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index 64d70199..f4c86a1e 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -917,11 +917,9 @@ def hypothesis_recovery( # Re-normalizing, regardless of filter results (i.e. filtered_count) post_filtered_df = manifest_list[i] total_abundance = post_filtered_df['rel_abund'].sum() - print(f"Total abundance") - print(total_abundance) if total_abundance > 0: manifest_list[i].loc[:, 'rel_abund'] = manifest_list[i]['rel_abund'] / total_abundance logger.info(f"Relative abundance normalized (total coverage: {total_abundance:.2f}.)") else: logger.warning(f"No relative abundance was done after ANI filtering.") - return manifest_list \ No newline at end of file + return manifest_list diff --git a/src/yacht/run_YACHT.py b/src/yacht/run_YACHT.py index 56c0a516..6a399f02 100644 --- a/src/yacht/run_YACHT.py +++ b/src/yacht/run_YACHT.py @@ -277,13 +277,19 @@ def main(args): temp_manifest = manifest_list[0].copy() if not show_all: temp_manifest = temp_manifest[temp_manifest["in_sample_est"] == True] + # Adding re-normalization here also + total_abundance = temp_manifest['rel_abund'].sum() + if total_abundance > 0: + temp_manifest.loc[:, 'rel_abund'] = temp_manifest['rel_abund'] / total_abundance + print(f"Second re-normalization") + print(temp_manifest['rel_abund'].sum()) temp_manifest.to_excel(writer, sheet_name="calculated_coverage", index=False) else: # Original behavior: multiple sheets based on min_coverage_list # save the raw results (i.e., min_coverage=1.0) if keep_raw: - temp_mainifest = manifest_list[0].copy() - temp_mainifest.rename( + temp_manifest = 
manifest_list[0].copy() + temp_manifest.rename( columns={ "acceptance_threshold_with_coverage": "acceptance_threshold_wo_coverage", "actual_confidence_with_coverage": "actual_confidence_wo_coverage", @@ -291,16 +297,16 @@ def main(args): }, inplace=True, ) - temp_mainifest.to_excel(writer, sheet_name="raw_result", index=False) + temp_manifest.to_excel(writer, sheet_name="raw_result", index=False) # save the results with different min_coverage given by the user if not has_raw: min_coverage_list = min_coverage_list[1:] manifest_list = manifest_list[1:] - for min_coverage, temp_mainifest in zip(min_coverage_list, manifest_list): + for min_coverage, temp_manifest in zip(min_coverage_list, manifest_list): if not show_all: - temp_mainifest = temp_mainifest[temp_mainifest["in_sample_est"] == True] - temp_mainifest.to_excel( + temp_manifest = temp_manifest[temp_manifest["in_sample_est"] == True] + temp_manifest.to_excel( writer, sheet_name=f"min_coverage{min_coverage}", index=False ) From edf75dd8316e3092a590dffc3f047cb94437c201 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Fri, 10 Apr 2026 15:02:58 -0400 Subject: [PATCH 33/41] Added re-norm for times when wta is on but calc. cov. is off. 
--- src/yacht/run_YACHT.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/yacht/run_YACHT.py b/src/yacht/run_YACHT.py index 6a399f02..eb1078a1 100644 --- a/src/yacht/run_YACHT.py +++ b/src/yacht/run_YACHT.py @@ -306,6 +306,10 @@ def main(args): for min_coverage, temp_manifest in zip(min_coverage_list, manifest_list): if not show_all: temp_manifest = temp_manifest[temp_manifest["in_sample_est"] == True] + #adding renormilization for the original behavior + total_abundance = temp_manifest['rel_abund'].sum() + if total_abundance > 0: + temp_manifest.loc[:, 'rel_abund'] = temp_manifest['rel_abund'] / total_abundance temp_manifest.to_excel( writer, sheet_name=f"min_coverage{min_coverage}", index=False ) From 449cab342fcbf9e50f127ad35960eaafb175e992 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Fri, 10 Apr 2026 15:19:04 -0400 Subject: [PATCH 34/41] Fixed tab issue. --- src/yacht/run_YACHT.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/yacht/run_YACHT.py b/src/yacht/run_YACHT.py index eb1078a1..b2633414 100644 --- a/src/yacht/run_YACHT.py +++ b/src/yacht/run_YACHT.py @@ -307,7 +307,7 @@ def main(args): if not show_all: temp_manifest = temp_manifest[temp_manifest["in_sample_est"] == True] #adding renormilization for the original behavior - total_abundance = temp_manifest['rel_abund'].sum() + total_abundance = temp_manifest['rel_abund'].sum() if total_abundance > 0: temp_manifest.loc[:, 'rel_abund'] = temp_manifest['rel_abund'] / total_abundance temp_manifest.to_excel( From ffeebe8b31b0fada2bfcf54783910ddd86101878 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Fri, 10 Apr 2026 15:43:08 -0400 Subject: [PATCH 35/41] Removed print statements. 
--- src/yacht/run_YACHT.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/yacht/run_YACHT.py b/src/yacht/run_YACHT.py index b2633414..e037b36a 100644 --- a/src/yacht/run_YACHT.py +++ b/src/yacht/run_YACHT.py @@ -281,8 +281,6 @@ def main(args): total_abundance = temp_manifest['rel_abund'].sum() if total_abundance > 0: temp_manifest.loc[:, 'rel_abund'] = temp_manifest['rel_abund'] / total_abundance - print(f"Second re-normalization") - print(temp_manifest['rel_abund'].sum()) temp_manifest.to_excel(writer, sheet_name="calculated_coverage", index=False) else: # Original behavior: multiple sheets based on min_coverage_list From 8a6e065d9b6608a09aa0bfebb81ac1c180b4dc13 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Tue, 14 Apr 2026 15:58:02 -0400 Subject: [PATCH 36/41] Small change to how organism_id_list is created to accommodate gtdb genomes during yacht convert. --- src/yacht/standardize_yacht_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/yacht/standardize_yacht_output.py b/src/yacht/standardize_yacht_output.py index 8fdac045..822cf763 100644 --- a/src/yacht/standardize_yacht_output.py +++ b/src/yacht/standardize_yacht_output.py @@ -247,7 +247,7 @@ def __to_cami(self, sample_name): ## select the organisms that YACHT considers to present in the sample yacht_res_df = self.yacht_output.copy() - organism_id_list = yacht_res_df["organism_name"].tolist() + organism_id_list = yacht_res_df["organism_name"].str.split().str[0].tolist() if len(organism_id_list) == 0: logger.error("No organism is detected by YACHT.") From 85d05be280a909bdc9645f55e59af17c1e8aa20a Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Wed, 15 Apr 2026 10:57:21 -0400 Subject: [PATCH 37/41] Improvements to yacht convert functionality with cov_calc output.
--- src/yacht/standardize_yacht_output.py | 44 +++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/src/yacht/standardize_yacht_output.py b/src/yacht/standardize_yacht_output.py index 822cf763..4effe10d 100644 --- a/src/yacht/standardize_yacht_output.py +++ b/src/yacht/standardize_yacht_output.py @@ -32,8 +32,18 @@ def add_arguments(parser): parser.add_argument( "--sheet_name", type=str, - help="The sheet name of the YACHT output excel file.", - required=True, + help="The sheet name of the YACHT output excel file. " + "Required unless --single_sheet is used.", + required=False, + default=None, + ) + parser.add_argument( + "--single_sheet", + action="store_true", + default=False, + help="Use this flag when the input was produced with --calculate_coverage. " + "Automatically selects the 'calculated_coverage' sheet. " + "Cannot be used together with --sheet_name.", ) parser.add_argument( "--genome_to_taxid", @@ -70,13 +80,24 @@ def add_arguments(parser): def main(args): yacht_output = args.yacht_output - sheet_name = args.sheet_name genome_to_taxid = args.genome_to_taxid mode = args.mode sample_name = args.sample_name outfile_prefix = args.outfile_prefix outdir = args.outdir + # resolves sheet name from --sheet_name or --single_sheet + if args.single_sheet and args.sheet_name: + logger.error("--single_sheet and --sheet_name are mutually exclusive.") + raise ValueError + elif args.single_sheet: + sheet_name = "calculated_coverage" + elif args.sheet_name: + sheet_name = args.sheet_name + else: + logger.error("One of either --sheet_name or --single_sheet must be provided.") + raise ValueError + # check if the yacht output file exists if not os.path.exists(yacht_output): logger.error(f"{yacht_output} does not exist.") @@ -92,9 +113,20 @@ def main(args): os.makedirs(outdir) # load the yacht output - yacht_output_df = pd.read_excel( - yacht_output, sheet_name=sheet_name, engine="openpyxl" - ) + try: + yacht_output_df = pd.read_excel( + 
yacht_output, sheet_name=sheet_name, engine="openpyxl" + ) + except ValueError as e: + if args.single_sheet: + logger.error( + f"Sheet 'calculated_coverage' not found in {yacht_output}. " + "This sheet is only present when YACHT was run with --calculate_coverage. " + "If you used a min_coverage list instead, use --sheet_name to specify the sheet." + ) + else: + logger.error(f"Sheet '{sheet_name}' not found in {yacht_output}: {e}") + raise # converet the first column to string yacht_output_df["organism_name"] = yacht_output_df["organism_name"].astype(str) From 9afbc8aaf5ed8be9176c5a15d19471d989bb672f Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn" Date: Wed, 15 Apr 2026 16:48:45 -0400 Subject: [PATCH 38/41] Replaced counts with weights for yacht convert when paired with calculate_coverage sheets. Added various handling improvements, including a clear help message. --- src/yacht/standardize_yacht_output.py | 55 +++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/src/yacht/standardize_yacht_output.py b/src/yacht/standardize_yacht_output.py index 4effe10d..26661af6 100644 --- a/src/yacht/standardize_yacht_output.py +++ b/src/yacht/standardize_yacht_output.py @@ -41,9 +41,12 @@ def add_arguments(parser): "--single_sheet", action="store_true", default=False, - help="Use this flag when the input was produced with --calculate_coverage. " - "Automatically selects the 'calculated_coverage' sheet. " - "Cannot be used together with --sheet_name.", + help="Automatically selects the 'calculated_coverage' sheet produced when " + "yacht run is invoked with --calculate_coverage. For full relative " + "abundance percentages, yacht run must also have been called with " + "--winner_takes_all; otherwise percentages revert to count-based. " + "Mutually exclusive with --sheet_name." 
+ ) parser.add_argument( "--genome_to_taxid", @@ -126,7 +129,7 @@ def main(args): ) else: logger.error(f"Sheet '{sheet_name}' not found in {yacht_output}: {e}") - raise + sys.exit(1) # converet the first column to string yacht_output_df["organism_name"] = yacht_output_df["organism_name"].astype(str) @@ -290,6 +293,40 @@ def __to_cami(self, sample_name): "genome_id in @organism_id_list" ).reset_index(drop=True) + ## Determines per-organism weights for percentage calculation. + use_rel_abund = ( + "rel_abund" in self.yacht_output.columns + and self.yacht_output["rel_abund"].notna().any() + ) + if "rel_abund" in self.yacht_output.columns and not use_rel_abund: + logger.warning( + "rel_abund column is present but empty — YACHT run was likely executed without " + "--winner_takes_all." + ) + if use_rel_abund: + genome_id_set = set(selected_organism_metadata_df["genome_id"]) + org_weights = ( + self.yacht_output + .assign(_key=self.yacht_output["organism_name"].str.split().str[0]) + .set_index("_key")["rel_abund"] + .reindex(sorted(genome_id_set), fill_value=0.0) + .fillna(0.0) + ) + total_weight = org_weights.sum() + if total_weight == 0: + logger.warning( + "Sum of rel_abund is zero; reverting to count-based percentages." + ) + use_rel_abund = False + else: + weight_lookup = org_weights.to_dict() + elif not use_rel_abund: + logger.warning( + "Falling back to count-based percentages." 
+ ) + total_weight = float(len(selected_organism_metadata_df)) + weight_lookup = None + ## Summarize the results summary_dict = {} for row in selected_organism_metadata_df.to_numpy(): @@ -301,6 +338,7 @@ def __to_cami(self, sample_name): taxid_list = list(np.array(row[3].split("|"))[select_index]) lineage_list = list(np.array(row[4].split("|"))[select_index]) rank_list = list(np.array(row[5].split("|"))[select_index]) + weight = weight_lookup.get(row[0], 0.0) if use_rel_abund else 1.0 current_lineage = "" current_taxid = "" for index, (taxid, rank, lineage) in enumerate( @@ -317,15 +355,15 @@ def __to_cami(self, sample_name): "RANK": rank, "TAXPATH": current_taxid, "TAXPATHSN": current_lineage, - "count": 1, + "weight": weight, } else: - summary_dict[taxid]["count"] += 1 + summary_dict[taxid]["weight"] += weight # calculate percentage for taxid in summary_dict: summary_dict[taxid]["PERCENTAGE"] = ( - summary_dict[taxid]["count"] / len(selected_organism_metadata_df) * 100 + summary_dict[taxid]["weight"] / total_weight * 100 ) ## sort by rank in allowable rank list @@ -335,8 +373,9 @@ def __to_cami(self, sample_name): .rename(columns={"index": "TAXID"}) ) res_df = [summary_df.query(f'RANK == "{rank}"') for rank in self.allowable_rank] - res_df = pd.concat(res_df).drop(columns=["count"]).reset_index(drop=True) + res_df = pd.concat(res_df).drop(columns=["weight"]).reset_index(drop=True) res_df.columns = ["@@TAXID", "RANK", "TAXPATH", "TAXPATHSN", "PERCENTAGE"] + res_df["PERCENTAGE"] = res_df["PERCENTAGE"].round(6) ## output summary results if len(res_df) == 0: From eb7632990ead191a6396695c24afc0cfbe8a77b5 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn." 
Date: Wed, 29 Apr 2026 14:10:05 -0400 Subject: [PATCH 39/41] Version bump to 2.1.0 --- src/yacht/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/yacht/utils.py b/src/yacht/utils.py index 9abb2a1b..c4f5f0dd 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -39,7 +39,7 @@ ksize: int = 31 # Note: hard-coding this for now # Set up global variables -__version__ = "2.0.1" +__version__ = "2.1.0" GITHUB_API_URL = "https://api.github.com/repos/KoslickiLab/YACHT/contents/demo/{path}" GITHUB_RAW_URL = "https://raw.githubusercontent.com/KoslickiLab/YACHT/main/demo/{path}" BASE_URL = "https://farm.cse.ucdavis.edu/~ctbrown/sourmash-db/" From 97e1d59f71ab72168fa2a808690e1053d1b0491a Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn." Date: Wed, 29 Apr 2026 14:34:41 -0400 Subject: [PATCH 40/41] Added --min_ani flag, N-R convergence mode, and introduced sample-wide median lambda as fallback coverage. --- src/yacht/cov_calc.py | 16 +++--- src/yacht/hypothesis_recovery_src.py | 84 ++++++++++++++++++---------- src/yacht/run_YACHT.py | 28 ++++++++++ src/yacht/utils.py | 30 ++++++---- 4 files changed, 110 insertions(+), 48 deletions(-) diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py index 3bbbe441..89cb3900 100644 --- a/src/yacht/cov_calc.py +++ b/src/yacht/cov_calc.py @@ -21,7 +21,7 @@ winner_map = None #skipping this step in this version kmers_lost_count = None -def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.SourmashSignature): +def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.SourmashSignature, convergence_nr: bool = True): """ Function that calculates lambda according to Shaw and Yu (2024) from two sourmash.Minshash files (resresenting the sample and the genome sketches). """ @@ -77,12 +77,12 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma break # Check if cov_max remains inf (i.e. 
no valid maximum found) - if cov_max == float('inf'): - logger.warning( - f"Could not determine valid coverage maximum for genome {genome_sig.name}. " - f"Median coverage: {median_cov}. Returning None." - ) - return None + if cov_max == float('inf'): + logger.debug( + f"No coverage outliers found for genome {genome_sig.name} " + f"(median_cov={median_cov}). Retaining all coverage values (cov_max=inf), " + ) + # cov_max remains float('inf'), so all covs pass the filter below; consistent with sylph behavior full_covs = [0] * (len(gn_hashes) - contain_count) @@ -106,7 +106,7 @@ def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.Sourma elif (myArgs.bin == True): test_lambda = binary_search_lambda(full_covs) elif (myArgs.mle) == True: - test_lambda = mle_zip(full_covs, gn_kmers_items) + test_lambda = mle_zip(full_covs, gn_kmers_items, convergence_nr) else: test_lambda = ratio_lambda(full_covs, MIN_COUNT_THRESH) diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index f4c86a1e..980c4c0a 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -14,7 +14,6 @@ from .utils import ( load_signature_with_ksize, decompress_all_sig_files, - MIN_ANI_THRESHOLD, ratio_lambda, ani_from_lambda, MIN_COUNT_THRESH, @@ -127,17 +126,20 @@ def get_organisms_with_nonzero_overlap( return multisearch_result["match_name"].to_list() -# Global variable for sharing sample signature across worker processes +# Global variables for sharing state across worker processes _worker_sample_sig = None +_worker_convergence_nr = True -def _init_coverage_worker(sample_sig): +def _init_coverage_worker(sample_sig, convergence_nr): """ - Initializer for worker processes to set up shared sample signature. + Initializer for worker processes to set up shared sample signature and convergence flag. 
:param sample_sig: Sample signature to be shared across all workers + :param convergence_nr: Whether to use convergence criterion in Newton-Raphson """ - global _worker_sample_sig + global _worker_sample_sig, _worker_convergence_nr _worker_sample_sig = sample_sig + _worker_convergence_nr = convergence_nr def _calculate_coverage_worker(args): """ @@ -153,7 +155,7 @@ def _calculate_coverage_worker(args): os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX), ksize, ) - result_df = cov_calc(_worker_sample_sig, sig) + result_df = cov_calc(_worker_sample_sig, sig, _worker_convergence_nr) if result_df is not None: # Add organism_name to the result for proper matching result_df['organism_name'] = organism_name @@ -173,6 +175,8 @@ def get_exclusive_hashes( winner_takes_all: bool = False, batch_size: int = 1000, two_pass: bool = True, + convergence_nr: bool = True, + min_ani: float = 0.95, ) -> Tuple[List[Tuple[int, int]], pd.DataFrame, pd.DataFrame]: """ This function gets the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample, and @@ -262,7 +266,7 @@ def __find_exclusive_hashes( chunk_size = max(1, len(organism_md5sum_list) // (num_threads * 50)) logger.info(f"Using chunk size of {chunk_size} for parallel processing") - with Pool(processes=num_threads, initializer=_init_coverage_worker, initargs=(sample_sig,)) as pool: + with Pool(processes=num_threads, initializer=_init_coverage_worker, initargs=(sample_sig, convergence_nr)) as pool: # Prepare arguments for parallel processing (sample_sig shared via initializer to avoid pickling overhead) # Include organism_name for proper matching (fixes misalignment bug from imap_unordered) organism_name_list = sub_manifest["organism_name"].to_list() @@ -310,7 +314,7 @@ def __find_exclusive_hashes( # Recalculate ANI using only won k-mers (prepares for Pass 2) final_stats_df = recalculate_ani_from_winner_map( - final_stats_df, winner_map, sample_sig, ksize, batch_size + 
final_stats_df, winner_map, sample_sig, ksize, batch_size, min_ani=min_ani ) # Pass 2: Rebuild winner map with refined ANI estimates @@ -412,14 +416,14 @@ def recalculate_ani_from_winner_map( winner_map: Dict[int, Tuple[float, str]], sample_sig: sourmash.SourmashSignature, ksize: int, - batch_size: int = 1000 + batch_size: int = 1000, + min_ani: float = 0.95 ) -> pd.DataFrame: """ Recalculates ANI for each organism using only k-mers it 'won' in the winner map. This implements Pass 2 of sylph's two-pass winner-takes-all approach. Organisms that lost all their k-mers are marked as 'eliminated' with rel_abund=0. - This is Option D handling: keep in results but flag as inconclusive. :param final_stats_df: DataFrame with coverage statistics including organism_name, final_est_ani, and genome_sketch columns @@ -466,19 +470,24 @@ def recalculate_ani_from_winner_map( if kmer in sample_hashes and sample_hashes[kmer] > 0: won_kmers_in_sample.append(sample_hashes[kmer]) - # Handle organisms that lost all k-mers (Option D: mark as eliminated) + # Handle organisms that lost all k-mers if total_won_kmers == 0: final_stats_df.at[idx, 'reassignment_status'] = 'eliminated' final_stats_df.at[idx, 'final_est_ani'] = float('nan') eliminated_count += 1 continue - # Check if we have enough data for lambda estimation + # Check if we have enough data for lambda estimation if len(won_kmers_in_sample) < SAMPLE_SIZE_CUTOFF: - # Not enough won k-mers in sample - mark as eliminated - final_stats_df.at[idx, 'reassignment_status'] = 'eliminated' - final_stats_df.at[idx, 'final_est_ani'] = float('nan') - eliminated_count += 1 + # Not enough won k-mers for reliable lambda re-estimation, but don't eliminate. + # Compute naive ANI from won k-mers; only update final_est_ani if the naive + # estimate is above threshold — otherwise retains the pre-WTA estimate. 
+ if total_won_kmers > 0: + naive_won_ani = (len(won_kmers_in_sample) / total_won_kmers) ** (1 / ksize) + if naive_won_ani >= min_ani: + final_stats_df.at[idx, 'final_est_ani'] = naive_won_ani + # else: retain original pre-WTA final_est_ani + final_stats_df.at[idx, 'reassignment_status'] = 'lambda_failed' continue # Build full_cov array (zeros for won k-mers not in sample + coverages for those in sample) @@ -559,7 +568,7 @@ def estimate_relative_abundance( row = final_stats_df.iloc[idx] organism_name = row['organism_name'] - # Skip eliminated organisms (Option D: they keep rel_abund = 0) + # Skip eliminated organisms if has_reassignment_status and row['reassignment_status'] == 'eliminated': continue @@ -712,6 +721,8 @@ def hypothesis_recovery( batch_size: int = 1000, two_pass: bool = True, calculate_coverage: bool = False, + convergence_nr: bool = True, + min_ani: float = 0.95, ): """ Go through each of the organisms that have non-zero overlap with the sample and perform a hypothesis test for the @@ -769,7 +780,7 @@ def hypothesis_recovery( # Get the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample exclusive_hashes_info, manifest, final_stats_df = get_exclusive_hashes( manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir, - num_threads, winner_takes_all, batch_size, two_pass + num_threads, winner_takes_all, batch_size, two_pass, convergence_nr, min_ani ) # Set up the results dataframe columns @@ -793,6 +804,8 @@ def hypothesis_recovery( # Using multiprocessing.Pool to parallelize the execution manifest_list = [] + fallback_coverage = min(min_coverage_list) if min_coverage_list else 0.1 + if calculate_coverage: # CALCULATE_COVERAGE MODE: Use calculated coverage (final_est_cov) per organism logger.info("Using calculate_coverage mode: applying calculated coverage per organism") @@ -801,6 +814,20 @@ def hypothesis_recovery( # Coverage depth (lambda) is converted to detection fraction using Poisson 
probability: # P(a k-mer is detected) = 1 - exp(-lambda) # This gives the expected fraction of k-mers that will be observed at least once. + # Compute sample-wide median lambda for fallback coverage + valid_lambdas = [ + row['final_est_cov'] for _, row in final_stats_df.iterrows() + if pd.notna(row['final_est_cov']) and row['final_est_cov'] > 0 + ] + if valid_lambdas: + median_lambda = np.median(valid_lambdas) + fallback_coverage = 1.0 - np.exp(-median_lambda) + logger.info(f"Sample-wide median lambda: {median_lambda:.4f}, " + f"fallback detection fraction: {fallback_coverage:.4f}") + else: + fallback_coverage = 0.1 + logger.warning("No valid lambda estimates in sample; using fallback coverage 0.1") + coverage_map = {} for _, row in final_stats_df.iterrows(): org_name = row['organism_name'] @@ -812,16 +839,15 @@ def hypothesis_recovery( # Convert depth to the expected fraction of k-mers detected coverage_map[org_name] = 1.0 - np.exp(-cov_val) elif pd.notna(median_cov) and median_cov > 0: - # Fallback: use median_cov when lambda estimation failed - # This helps detect low-abundance taxa where lambda couldn't be estimated - # (e.g., fewer than 25 non-zero coverage values) - coverage_map[org_name] = 1.0 - np.exp(-median_cov) - logger.warning(f"No valid lambda for {org_name}, using median_cov={median_cov:.3f} " - f"(detection fraction: {coverage_map[org_name]:.3f})") + # Fallback: use sample-wide median lambda rather than per-organism + # median_cov, since median_cov at low coverage (e.g. 
1x) gives an + # artificially strict detection fraction of (1 - e^-1) + coverage_map[org_name] = fallback_coverage + logger.warning(f"No valid lambda for {org_name}, using fallback_coverage={fallback_coverage:.4f}") else: - # Last resort: if neither is available, use strictest test - coverage_map[org_name] = 1.0 - logger.warning(f"No valid coverage data for {org_name}, using default 1.0") + # The last resort + coverage_map[org_name] = fallback_coverage + logger.warning(f"No valid coverage data for {org_name}, using fallback_coverage={fallback_coverage:.4f}") # Get organism names in manifest order (aligned with exclusive_hashes_info) organism_names = manifest["organism_name"].to_list() @@ -902,12 +928,12 @@ def hypothesis_recovery( ) # ANI threshold filtering - logger.info(f"Filtering organisms with final_est_ani < {MIN_ANI_THRESHOLD} ({MIN_ANI_THRESHOLD*100:.0f}% ANI)") + logger.info(f"Filtering organisms with final_est_ani < {min_ani} ({min_ani*100:.0f}% ANI)") for i in range(len(manifest_list)): initial_count = len(manifest_list[i]) # Keep organisms with ANI >= threshold OR organisms with no ANI estimate (NaN) manifest_list[i] = manifest_list[i][ - (manifest_list[i]['final_est_ani'] >= MIN_ANI_THRESHOLD) | + (manifest_list[i]['final_est_ani'] >= min_ani) | (manifest_list[i]['final_est_ani'].isna()) ].reset_index(drop=True) filtered_count = initial_count - len(manifest_list[i]) diff --git a/src/yacht/run_YACHT.py b/src/yacht/run_YACHT.py index e037b36a..0bd9bcc1 100644 --- a/src/yacht/run_YACHT.py +++ b/src/yacht/run_YACHT.py @@ -96,6 +96,25 @@ def add_arguments(parser): "Cannot be used together with --min_coverage_list.", default=False, ) + parser.add_argument( + "--convergence_nr", + action="store_true", + help="Turn on the convergence criterion in the Newton-Raphson lambda estimator, " + "terminating when the update falls below " + f"LAMBDA_EPSILON ({utils.LAMBDA_EPSILON}). 
" + "By default, all 1000 iterations run without terminating, " + "matching the original sylph behavior. ", + default=False, + ) + parser.add_argument( + "--min_ani", + type=float, + help="Minimum ANI threshold for retaining organisms in results. " + "Organisms whose final estimated ANI falls below this value are filtered out. " + "Default: 0.95 (species-level boundary).", + required=False, + default=0.95, + ) parser.add_argument( "--out", type=str, @@ -116,6 +135,13 @@ def main(args): keep_raw = args.keep_raw # Keep raw results in output file. show_all = args.show_all # Show all organisms (no matter if present) in output file. calculate_coverage = args.calculate_coverage # Use calculated coverage instead of user-supplied list + convergence_nr = args.convergence_nr # Use convergence criterion in Newton-Raphson (default: False) + min_ani = args.min_ani # Minimum ANI threshold for filtering organisms + + if not (0.90 <= min_ani <= 1): + raise ValueError( + f"--min_ani value {min_ani} must be between 0.90 (genus-level) and 1 (both inclusive)." 
+ ) out = str(Path(args.out).absolute()) # full path to output excel file # Validate mutual exclusivity of --calculate_coverage and --min_coverage_list @@ -253,6 +279,8 @@ def main(args): batch_size, two_pass, calculate_coverage, + convergence_nr, + min_ani, ) # remove unnecessary columns diff --git a/src/yacht/utils.py b/src/yacht/utils.py index c4f5f0dd..9a86d6a2 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -31,7 +31,6 @@ # Sylph (Shaw and Yu, 2024) related constants SAMPLE_SIZE_CUTOFF: int = 25 PVALUE_CUTOFF: float = 0.9999999999 -MIN_ANI_THRESHOLD: float = 0.90 # Minimum ANI threshold for filtering organisms MEDIAN_ANI_THRESHOLD: float = 3.00 MAX_MEDIAN_FOR_MEAN_FINAL_EST: float = 15.0 MIN_COUNT_THRESH: int = 3 @@ -39,7 +38,7 @@ ksize: int = 31 # Note: hard-coding this for now # Set up global variables -__version__ = "2.1.0" +__version__ = "2.0.1" GITHUB_API_URL = "https://api.github.com/repos/KoslickiLab/YACHT/contents/demo/{path}" GITHUB_RAW_URL = "https://raw.githubusercontent.com/KoslickiLab/YACHT/main/demo/{path}" BASE_URL = "https://farm.cse.ucdavis.edu/~ctbrown/sourmash-db/" @@ -614,22 +613,31 @@ def load_one_sig(sig_path: str, ksize: int): ) return(loaded_sig) -def newton_raphson(ratio: float, mean: float): +def newton_raphson(ratio: float, mean: float, convergence: bool = True): """ Shaw and Yu (2024)'s implmentation of Newton-Raphson use to assist in the calculation of lambda. 
""" + ratio = min(ratio, 1.0 - LAMBDA_EPSILON) curr = mean / (1 - ratio) - - for _ in range(1000): #iterates to converge on an approximation for the root + + for _ in range(1000): t1 = (1 - ratio) * curr e_curr = math.exp(-curr) t2 = mean * (1 - e_curr) - t3 = 1 - ratio + t3 = 1 - ratio t4 = mean * e_curr - curr = curr - (t1 - t2) / (t3 - t4) + denom = t3 - t4 + if abs(denom) < LAMBDA_EPSILON: + break + prev = curr + curr = curr - (t1 - t2) / denom + if not math.isfinite(curr): + return None + if convergence and abs(curr - prev) < LAMBDA_EPSILON: + break return curr -def mle_zip(full_covs: list[int], _k: float): +def mle_zip(full_covs: list[int], _k: float, convergence: bool = True): """ Maximum likelihood estimator for the zero-inflated Poisson (ZIP) distribution from Shaw and Yu (2024) """ @@ -649,10 +657,10 @@ def mle_zip(full_covs: list[int], _k: float): return None mean = np.mean(full_covs) - nr_input = num_zero/len(full_covs) - lambda_out = newton_raphson(nr_input, mean) + nr_input = n_zero / len(full_covs) + lambda_out = newton_raphson(nr_input, mean, convergence) - if lambda_out < 0 or math.isnan(lambda_out): + if lambda_out is None or lambda_out < 0 or not math.isfinite(lambda_out): lambda_ret = None else: lambda_ret = lambda_out From 52e4048de62be01c7bbd6fc0d84de379206807a4 Mon Sep 17 00:00:00 2001 From: "R. Taylor Raborn." 
Date: Wed, 29 Apr 2026 14:42:00 -0400 Subject: [PATCH 41/41] Version bump to 2.2.0 --- src/yacht/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/yacht/utils.py b/src/yacht/utils.py index 9a86d6a2..325a27e8 100755 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -38,7 +38,7 @@ ksize: int = 31 # Note: hard-coding this for now # Set up global variables -__version__ = "2.0.1" +__version__ = "2.2.0" GITHUB_API_URL = "https://api.github.com/repos/KoslickiLab/YACHT/contents/demo/{path}" GITHUB_RAW_URL = "https://raw.githubusercontent.com/KoslickiLab/YACHT/main/demo/{path}" BASE_URL = "https://farm.cse.ucdavis.edu/~ctbrown/sourmash-db/"