diff --git a/.gitignore b/.gitignore index c9474082..ce9ee735 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,13 @@ gtdb-rs214-reps.k31_0.9995_pretrained/ # added by mahmudhera src/cpp/main.o .gitignore + +# added by rtraborn +tests_sra_data/ +*.fastq +*.fastq.gz +*.sig.zip +*_temp/ +*_intermediate_files/ +*.srademo/query_data/*.zip +*.sra diff --git a/README.md b/README.md index a15e3d6a..de4bcf3a 100644 --- a/README.md +++ b/README.md @@ -341,22 +341,33 @@ The `--min_coverage_list` parameter dictates a list of `min_coverage` which indi The output file will be an EXCEL file; column descriptions can be found [here](docs/column_descriptions.csv). The most important are the following: +**Core Detection Columns:** * `organism_name`: The name of the organism -* `in_sample_est`: A boolean value either False or True: if False, there was not enough evidence to claim this organism is present in the sample. +* `in_sample_est`: A boolean value either False or True: if False, there was not enough evidence to claim this organism is present in the sample. * `p_vals`: Probability of observing this or more extreme result at the given ANI threshold, assuming the null hypothesis. - -Other interesting columns include: - * `num_exclusive_kmers_to_genome`: How many k-mers were found in this organism and no others * `num_matches`: How many k-mers were found in this organism and the sample * `acceptance_threshold_*`: How many k-mers must be found in this organism to be considered "present" at the given ANI threshold. Hence, `in_sample_est` is True if `num_matches` >= `acceptance_threshold_*` (adjusting by coverage if desired). * `alt_confidence_mut_rate_*`: What the mutation rate (1-ANI) would need to be to get your false positive to match the false negative rate of 1-`significance` (adjusting by coverage if desired). +**Coverage Statistics (described in sylph: Shaw & Yu, 2024 | New in superyacht):** +* `naive_ani`: Simple ANI estimate from k-mer containment (0-1, multiply by 100 for percentage) +* `final_est_ani`: Coverage-adjusted ANI estimate (more accurate than naive_ani) +* `final_est_cov`: Expected coverage (lambda parameter) - average sequencing depth for this organism +* `mean_cov` / `median_cov`: Coverage distribution statistics +* `lambda_status`: Coverage calculation method (LAMBDA/HIGH/LOW) + +**Relative Abundance (Winner Map):** +* `rel_abund`: Relative abundance estimate (0-1, normalized across all organisms in sample) +* `kmers_lost`: Number of k-mers reassigned to organisms with higher ANI + +**ANI Filtering:** Organisms with `final_est_ani < 0.90` (90% ANI) are automatically filtered from results to remove low-quality matches. +
### 4. Convert YACHT result to other popular output formats (yacht convert) -When we get the EXCEL result file from run_YACHT.py, you can run `yacht convert` to covert the YACHT result to other popular output formats (Currently, only `cami`, `biom`, `graphplan` are supported). +When you get the EXCEL result file from run_YACHT.py, you can run `yacht convert` to covert the YACHT result to other popular output formats (Currently, only `cami`, `biom`, `graphplan` are supported). __Note__: Before you run `yacht convert`, you need to prepare a TSV file `genome_to_taxid.tsv` containing two columns: genome ID (genome_id) and its corresponding taxid (taxid). An example can be found [here](demo/toy_genome_to_taxid.tsv). You need to prepare it according to the reference database genomes you used. @@ -374,8 +385,9 @@ yacht convert --yacht_output 'result.xlsx' --sheet_name 'min_coverage0.01' --gen | --genome_to_taxid | the path to the location of `genome_to_taxid.tsv` you prepared | | --mode | specify to which output format you want to convert (e.g., 'cami', 'biom', 'graphplan') | --sample_name | A random name you would like to show in header of the cami file. Default: Sample1.' | -| --outfile_prefix | the prefix of the output file. Default: result | +| --outfile_prefix | the prefix of the output file. Default: result | | --outdir | the path to output directory where the results will be genreated | +
diff --git a/src/yacht/cov_calc.py b/src/yacht/cov_calc.py new file mode 100644 index 00000000..89cb3900 --- /dev/null +++ b/src/yacht/cov_calc.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +import sourmash +import math +import numpy as np +import pandas as pd +from loguru import logger +from yacht.utils import ratio_lambda +from yacht.utils import mme_lambda +from yacht.utils import binary_search_lambda +from yacht.utils import mle_zip +from yacht.utils import bootstrap_interval +from yacht.utils import ani_from_lambda +from yacht.utils import _ContainArgs +from yacht.utils import AniResult +from yacht.utils import AdjustStatus, AdjustStatusType +from yacht.utils import SAMPLE_SIZE_CUTOFF, PVALUE_CUTOFF, MEDIAN_ANI_THRESHOLD, MAX_MEDIAN_FOR_MEAN_FINAL_EST, MIN_COUNT_THRESH, ksize +from scipy.stats import poisson, variation +from typing import Optional, Tuple, Dict, Any + +no_adj = False #consider updating this in future SUPERYACHT arguments +winner_map = None #skipping this step in this version +kmers_lost_count = None + +def cov_calc(sample_sig: sourmash.SourmashSignature, genome_sig: sourmash.SourmashSignature, convergence_nr: bool = True): + """ + Function that calculates lambda according to Shaw and Yu (2024) from two sourmash.Minshash files (resresenting the sample and the genome sketches). + """ + + myArgs = _ContainArgs() + + gn_hashes = genome_sig.minhash.hashes + gn_kmers_keys = genome_sig.minhash.hashes.keys() + gn_kmers_items = genome_sig.minhash.hashes.items() + gn_dict = dict(gn_kmers_items) + + sample_hashes_keys = sample_sig.minhash.hashes.keys() + samp_kmers_items = sample_sig.minhash.hashes.items() + samp_dict = dict(samp_kmers_items) + + covs = [] + contain_count = 0 + for kmer in gn_hashes: + if kmer in sample_hashes_keys: + if samp_dict[kmer] == 0: + continue + contain_count += 1 + covs.append(samp_dict[kmer]) + + if len(covs)==0: + return None + + naive_ani = math.pow(contain_count/len(gn_kmers_items), + 1/ksize) + + # Caps naive_ani at 1.0 to prevent biologically impossible ANIs + if naive_ani > 1.0: + logger.debug(f"Naive ANI {naive_ani:.6f} exceeds 1.0, capping at 1.0") + naive_ani = 1.0 + + covs.sort() + + if len(covs) == 0: + covs.append(0) + + len_ind = len(covs)//2 + median_cov = covs[len(covs)//2] + + pois_obj = poisson(median_cov) #creates a discrete frozen Poisson distribution object + cov_max = float('inf') + + if median_cov < 30: #if median coverage of 30 is not fulfilled + for i in range(len_ind,len(covs), 1): + cov = covs[i] + if pois_obj.cdf(cov) < PVALUE_CUTOFF: + cov_max = cov + else: + break + + # Check if cov_max remains inf (i.e. no valid maximum found) + if cov_max == float('inf'): + logger.debug( + f"No coverage outliers found for genome {genome_sig.name} " + f"(median_cov={median_cov}). Retaining all coverage values (cov_max=inf), " + ) + # cov_max remains float('inf'), so all covs pass the filter below; consistent with sylph behavior + + full_covs = [0] * (len(gn_hashes) - contain_count) + + for cov in covs: + if cov <= cov_max: + full_covs.append(cov) + var = variation(full_covs) + if var is not None: + logger.debug("VAR {} {}", var, genome_sig.name) + + mean_cov = sum(full_covs)//len(full_covs) + geq1_mean_cov = sum(full_covs)//len(covs) + if median_cov > MEDIAN_ANI_THRESHOLD: + return_lambda = AdjustStatus.high() + + else: + if (myArgs.ratio == True): + test_lambda = ratio_lambda(full_covs, MIN_COUNT_THRESH) + elif (myArgs.mme == True): + test_lambda = mme_lambda(full_covs) + elif (myArgs.bin == True): + test_lambda = binary_search_lambda(full_covs) + elif (myArgs.mle) == True: + test_lambda = mle_zip(full_covs, gn_kmers_items, convergence_nr) + else: + test_lambda = ratio_lambda(full_covs, MIN_COUNT_THRESH) + + if test_lambda is None: + return_lambda = AdjustStatus.low() + else: + return_lambda = AdjustStatus.lambda_value(test_lambda) + + match return_lambda.status: + + case AdjustStatusType.LAMBDA: + # executes if it is the Lambda case + final_est_cov = return_lambda.value + opt_lambda = final_est_cov + + case AdjustStatusType.HIGH: + # executes if it is high coverage case + if median_cov < MAX_MEDIAN_FOR_MEAN_FINAL_EST: + final_est_cov = geq1_mean_cov + else: + final_est_cov = median_cov + opt_lambda = final_est_cov + + case AdjustStatusType.LOW: + # if it is the "low" case + # final_est_cov logic is handled elsewhere, or use a default + opt_lambda = None + + # Adding a "wild-card" case, just to be safe + case _: + opt_lambda = None + + opt_est_ani = ani_from_lambda(opt_lambda, mean_cov, 31, full_covs) + + if opt_lambda == None or opt_est_ani == None or no_adj == True or return_lambda.status == AdjustStatusType.HIGH: + # Avoids adjusting for high-coverage (i.e. where median > MEDIAN_ANI_THRESHOLD) + final_est_ani = naive_ani + logger.debug(f"Using naive ANI (no adjustment needed): median_cov={median_cov}, naive_ani={naive_ani:.4f}") + else: + final_est_ani = opt_est_ani + + low_ani, high_ani, low_lambda, high_lambda= None, None, None, None + +#Conditional calculation of confidence intervals + + if myArgs.ci_int==True and opt_lambda is not None: + bootstrap = bootstrap_interval(full_covs, ksize, myArgs) + low_ani = bootstrap[0] + high_ani = bootstrap[1] + low_lambda = bootstrap[2] + high_lambda = bootstrap[3] + + if sample_sig.name: + seq_name = sample_sig.name + else: + seq_name = sample_sig.filename + +#This is code related to the winner_map situation + #kmers_lost = kmers_lost_count if winner_map is not None else None + + ani_result = AniResult( + naive_ani=naive_ani, + final_est_ani=final_est_ani, + final_est_cov=opt_lambda, + seq_name=seq_name, + gn_name=genome_sig.filename, + contig_name=genome_sig.name, + mean_cov=geq1_mean_cov, + median_cov=median_cov, + containment_index=(contain_count, len(gn_hashes)), + lambda_status=return_lambda, + ani_ci=(low_ani, high_ani), + lambda_ci=(low_lambda, high_lambda), + genome_sketch=genome_sig, + rel_abund=None, + seq_abund=None, + kmers_lost=None, + ) + + results = [ani_result] + + columns_ani = [ + "naive_ani", # the ani according to naive calculations + "final_est_ani", # final estimated ani + "final_est_cov", # The final estimated coverage + "seq_name", # the name of the sequence + "gn_name", # the name of the genome + "mean_cov", # the mean coverage observed + "median_cov", # the median coverage observed + "containment_index", #the containment index + "lambda_status", #lambda status + "ani_ci", #ani confidence interval + "lambda_ci", #lambda confidence interval + "genome_sketch", #genome Sourmash signature file + "rel_abund", #The taxonomic abundance observed (set to None for now) + "seq_abund", #the number of sequenes that match a given genome (set to None for now) + "kmers_lost" #the number of kmers that were reassigned during the tax triage step. (set to None for now) + ] + + cov_calc_df = pd.DataFrame(results, columns=columns_ani) + + return cov_calc_df + + + + + diff --git a/src/yacht/hypothesis_recovery_src.py b/src/yacht/hypothesis_recovery_src.py index 3420d36f..980c4c0a 100644 --- a/src/yacht/hypothesis_recovery_src.py +++ b/src/yacht/hypothesis_recovery_src.py @@ -10,14 +10,27 @@ from multiprocessing import Pool import sourmash import glob -from typing import List, Set, Tuple -from .utils import load_signature_with_ksize, decompress_all_sig_files +from typing import List, Set, Tuple, Dict +from .utils import ( + load_signature_with_ksize, + decompress_all_sig_files, + ratio_lambda, + ani_from_lambda, + MIN_COUNT_THRESH, + SAMPLE_SIZE_CUTOFF, +) # Configure Loguru logger from loguru import logger +from .cov_calc import cov_calc -warnings.filterwarnings("ignore") +""" +Hypothesis Recovery and Coverage Analysis Module +This module implements YACHT's core statistical framework for organism detection +in metagenomic samples, with integrated coverage modeling and abundance estimation. +""" +warnings.filterwarnings("ignore") logger.remove() logger.add( @@ -113,13 +126,58 @@ def get_organisms_with_nonzero_overlap( return multisearch_result["match_name"].to_list() +# Global variables for sharing state across worker processes +_worker_sample_sig = None +_worker_convergence_nr = True + +def _init_coverage_worker(sample_sig, convergence_nr): + """ + Initializer for worker processes to set up shared sample signature and convergence flag. + + :param sample_sig: Sample signature to be shared across all workers + :param convergence_nr: Whether to use convergence criterion in Newton-Raphson + """ + global _worker_sample_sig, _worker_convergence_nr + _worker_sample_sig = sample_sig + _worker_convergence_nr = convergence_nr + +def _calculate_coverage_worker(args): + """ + Worker function for parallel coverage calculation. + Uses global _worker_sample_sig instead of passing it as argument to avoid pickling overhead. + + :param args: Tuple of (md5sum, organism_name, path_to_genome_temp_dir, ksize) + :return: Result DataFrame from cov_calc with organism_name, or None if error occurs + """ + md5sum, organism_name, path_to_genome_temp_dir, ksize = args + try: + sig = load_signature_with_ksize( + os.path.join(path_to_genome_temp_dir, "signatures", md5sum + SIG_SUFFIX), + ksize, + ) + result_df = cov_calc(_worker_sample_sig, sig, _worker_convergence_nr) + if result_df is not None: + # Add organism_name to the result for proper matching + result_df['organism_name'] = organism_name + return result_df + except Exception as e: + logger.warning(f"Error calculating coverage for {organism_name} ({md5sum}): {e}") + return None + + def get_exclusive_hashes( manifest: pd.DataFrame, nontrivial_organism_names: List[str], sample_sig: sourmash.SourmashSignature, ksize: int, path_to_genome_temp_dir: str, -) -> Tuple[List[Tuple[int, int]], pd.DataFrame]: + num_threads: int = 16, + winner_takes_all: bool = False, + batch_size: int = 1000, + two_pass: bool = True, + convergence_nr: bool = True, + min_ani: float = 0.95, +) -> Tuple[List[Tuple[int, int]], pd.DataFrame, pd.DataFrame]: """ This function gets the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample, and then find how many are in the sample. @@ -137,13 +195,18 @@ def get_exclusive_hashes( :param sample_sig: the sample signature :param ksize: int (size of kmer) :param path_to_genome_temp_dir: string (path to the genome temporary directory generated by the training step) + :param num_threads: int (number of threads to use for parallel coverage calculation, default: 16) + :param winner_takes_all: bool (enable winner-takes-all k-mer reassignment and relative abundance, default: False) + :param batch_size: int (batch size for winner-takes-all processing, default: 1000) + :param two_pass: bool (use sylph's two-pass approach for more accurate k-mer reassignment, default: True) + Set to False for original one-pass behavior (for testing/comparison) :return: a list of tuples, each tuple contains the following information: 1. the number of unique hashes exclusive to the organism under consideration 2. the number of unique hashes exclusive to the organism under consideration that are in the sample a new manifest dataframe that only contains the organisms that have non-zero overlap with the sample + a dataframe with coverage statistics for each organism """ - def __find_exclusive_hashes( md5sum: str, path_to_temp_dir: str, @@ -155,7 +218,7 @@ def __find_exclusive_hashes( os.path.join(path_to_temp_dir, "signatures", md5sum + SIG_SUFFIX), ksize ) return {hash for hash in sig.minhash.hashes if hash in single_occurrence_hashes} - + # get manifest information for the organisms that have non-zero overlap with the sample sub_manifest = manifest.loc[ manifest["organism_name"].isin(nontrivial_organism_names), : @@ -177,6 +240,8 @@ def __find_exclusive_hashes( multiple_occurrence_hashes.add(hash) else: single_occurrence_hashes.add(hash) + + del multiple_occurrence_hashes # free up memory # Find hashes that are unique to each organism @@ -188,11 +253,47 @@ def __find_exclusive_hashes( md5sum, path_to_genome_temp_dir, ksize, single_occurrence_hashes ) ) + del single_occurrence_hashes # free up memory # Get sample hashes sample_hashes = set(sample_sig.minhash.hashes) + # Calculate coverage statistics for each organism (parallelized) + logger.info(f"Calculating coverage statistics using {num_threads} threads") + + # Calculates optimal chunk size for progress visibility + chunk_size = max(1, len(organism_md5sum_list) // (num_threads * 50)) + logger.info(f"Using chunk size of {chunk_size} for parallel processing") + + with Pool(processes=num_threads, initializer=_init_coverage_worker, initargs=(sample_sig, convergence_nr)) as pool: + # Prepare arguments for parallel processing (sample_sig shared via initializer to avoid pickling overhead) + # Include organism_name for proper matching (fixes misalignment bug from imap_unordered) + organism_name_list = sub_manifest["organism_name"].to_list() + args_list = [ + (md5sum, organism_name, path_to_genome_temp_dir, ksize) + for md5sum, organism_name in zip(organism_md5sum_list, organism_name_list) + ] + # Use imap_unordered for better performance (we are matching on organism_name) + stats_list = list( + tqdm( + pool.imap_unordered(_calculate_coverage_worker, args_list, chunksize=chunk_size), + total=len(organism_md5sum_list), + desc="Processing coverage per organism" + ) + ) + + # Filter out None results (from errors) + stats_list = [stats for stats in stats_list if stats is not None] + + if not stats_list: + raise ValueError("No coverage statistics were successfully calculated") + + # Concatenate all results - organism_name is already included in each DataFrame + final_stats_df = pd.concat(stats_list, ignore_index=True) + + del stats_list # free up memory + # Find hashes that are unique to each organism and in the sample logger.info("Finding hashes that are unique to each organism and in the sample") exclusive_hashes_info = [] @@ -203,7 +304,309 @@ def __find_exclusive_hashes( (len(exclusive_hashes), len(exclusive_hashes.intersection(sample_hashes))) ) - return exclusive_hashes_info, sub_manifest + # Conditionally run winner-takes-all (memory-intensive but provides relative abundance) + if winner_takes_all: + if two_pass: + # Two-pass approach (sylph-aligned): more accurate for closely related organisms + # Pass 1: Build initial winner map using original ANI estimates + logger.info("Pass 1: Building initial winner map for k-mer reassignment") + winner_map = build_winner_map(final_stats_df, path_to_genome_temp_dir, ksize, batch_size) + + # Recalculate ANI using only won k-mers (prepares for Pass 2) + final_stats_df = recalculate_ani_from_winner_map( + final_stats_df, winner_map, sample_sig, ksize, batch_size, min_ani=min_ani + ) + + # Pass 2: Rebuild winner map with refined ANI estimates + # Only include organisms that weren't eliminated in the recalculation + logger.info("Pass 2: Rebuilding winner map with refined ANI estimates") + winner_map = build_winner_map(final_stats_df, path_to_genome_temp_dir, ksize, batch_size) + else: + # One-pass approach (original): faster but less accurate for related organisms + logger.info("One-pass mode: Building winner map with initial ANI estimates") + winner_map = build_winner_map(final_stats_df, path_to_genome_temp_dir, ksize, batch_size) + + # Add placeholder columns for consistency + final_stats_df['reassignment_status'] = 'one_pass' + final_stats_df['original_ani'] = final_stats_df['final_est_ani'].copy() + + # Calculate relative abundance using final winner map + final_stats_df = estimate_relative_abundance(final_stats_df, winner_map, sample_sig, batch_size) + + # Free up memory by dropping genome_sketch column (no longer needed) + if 'genome_sketch' in final_stats_df.columns: + logger.info("Releasing genome signature objects to free memory") + final_stats_df.drop(columns=['genome_sketch'], inplace=True) + else: + logger.info("Skipping winner-takes-all (not enabled). Use --winner_takes_all to enable relative abundance estimation.") + # Add placeholder columns for consistency with winner-takes-all output + final_stats_df['rel_abund'] = float('nan') + final_stats_df['kmers_lost'] = 0 + final_stats_df['reassignment_status'] = 'not_applicable' + final_stats_df['original_ani'] = final_stats_df['final_est_ani'].copy() + + # Also drop genome_sketch to save memory + if 'genome_sketch' in final_stats_df.columns: + final_stats_df.drop(columns=['genome_sketch'], inplace=True) + + return exclusive_hashes_info, sub_manifest, final_stats_df + + +def build_winner_map( + final_stats_df: pd.DataFrame, + path_to_genome_temp_dir: str, + ksize: int, + batch_size: int = 1000 +) -> Dict[int, Tuple[float, str]]: + """ + Creates a "winner map" procedure that assigns k-mers to the organism with the highest ANI. + Uses memory-efficient batch processing. + + This implements the "winner takes all" strategy from sylph (Shaw and Yu, 2024) where + shared k-mers are assigned to the organism with the best ANI match, preventing double-counting. + Please note that this differs from the approach in sylph in that the procedure is run once, rather than + twice. + + :param final_stats_df: DataFrame with coverage statistics including organism_name, + final_est_ani, and genome_sketch columns + :param path_to_genome_temp_dir: Path to the directory containing genome signature files + :param ksize: k-mer size + :param batch_size: Number of organisms to process per batch (default: 1000) + :return: Dictionary mapping k-mer hash -> (ani, organism_name) + Only the organism with highest ANI "wins" each k-mer + """ + winner_map = {} + total_organisms = len(final_stats_df) + + logger.info(f"Building winner map for {total_organisms} organisms (batch size: {batch_size})") + + # Process in batches to control memory usage + for batch_start in range(0, total_organisms, batch_size): + batch_end = min(batch_start + batch_size, total_organisms) + batch_num = batch_start // batch_size + 1 + total_batches = (total_organisms + batch_size - 1) // batch_size + + for idx in tqdm( + range(batch_start, batch_end), + desc=f"Building winner map (batch {batch_num}/{total_batches})", + total=batch_end - batch_start + ): + row = final_stats_df.iloc[idx] + organism_name = row['organism_name'] + ani = row['final_est_ani'] + + # Skip organisms with no ANI estimate + if pd.isna(ani): + continue + + genome_sig = row['genome_sketch'] + + # For each k-mer, check if it should be reassigned + for kmer in genome_sig.minhash.hashes.keys(): + if kmer not in winner_map or ani > winner_map[kmer][0]: + winner_map[kmer] = (ani, organism_name) + + logger.info(f"Winner map built with {len(winner_map)} k-mers assigned") + + return winner_map + + +def recalculate_ani_from_winner_map( + final_stats_df: pd.DataFrame, + winner_map: Dict[int, Tuple[float, str]], + sample_sig: sourmash.SourmashSignature, + ksize: int, + batch_size: int = 1000, + min_ani: float = 0.95 +) -> pd.DataFrame: + """ + Recalculates ANI for each organism using only k-mers it 'won' in the winner map. + This implements Pass 2 of sylph's two-pass winner-takes-all approach. + + Organisms that lost all their k-mers are marked as 'eliminated' with rel_abund=0. + + :param final_stats_df: DataFrame with coverage statistics including organism_name, + final_est_ani, and genome_sketch columns + :param winner_map: k-mer to (ANI, organism_name) mapping from build_winner_map() + :param sample_sig: Sample signature with k-mer abundances + :param ksize: k-mer size for ANI calculation + :param batch_size: Number of organisms to process per batch (default: 1000) + :return: Updated DataFrame with recalculated ANI values and reassignment_status column + """ + logger.info("Recalculating ANI using only won k-mers (Pass 2 of two-pass approach)") + + sample_hashes = sample_sig.minhash.hashes + total_organisms = len(final_stats_df) + + # Initialize reassignment_status column + final_stats_df['reassignment_status'] = 'active' + + # Store original ANI for reference + final_stats_df['original_ani'] = final_stats_df['final_est_ani'].copy() + + eliminated_count = 0 + + for batch_start in range(0, total_organisms, batch_size): + batch_end = min(batch_start + batch_size, total_organisms) + batch_num = batch_start // batch_size + 1 + total_batches = (total_organisms + batch_size - 1) // batch_size + + for idx in tqdm( + range(batch_start, batch_end), + desc=f"Recalculating ANI (batch {batch_num}/{total_batches})", + total=batch_end - batch_start + ): + row = final_stats_df.iloc[idx] + organism_name = row['organism_name'] + genome_sig = row['genome_sketch'] + + # Count k-mers won by this organism and build coverage list + won_kmers_in_sample = [] + total_won_kmers = 0 + + for kmer in genome_sig.minhash.hashes.keys(): + if kmer in winner_map and winner_map[kmer][1] == organism_name: + total_won_kmers += 1 + if kmer in sample_hashes and sample_hashes[kmer] > 0: + won_kmers_in_sample.append(sample_hashes[kmer]) + + # Handle organisms that lost all k-mers + if total_won_kmers == 0: + final_stats_df.at[idx, 'reassignment_status'] = 'eliminated' + final_stats_df.at[idx, 'final_est_ani'] = float('nan') + eliminated_count += 1 + continue + + # Check if we have enough data for lambda estimation + if len(won_kmers_in_sample) < SAMPLE_SIZE_CUTOFF: + # Not enough won k-mers for reliable lambda re-estimation, but don't eliminate. + # Compute naive ANI from won k-mers; only update final_est_ani if the naive + # estimate is above threshold — otherwise retains the pre-WTA estimate. + if total_won_kmers > 0: + naive_won_ani = (len(won_kmers_in_sample) / total_won_kmers) ** (1 / ksize) + if naive_won_ani >= min_ani: + final_stats_df.at[idx, 'final_est_ani'] = naive_won_ani + # else: retain original pre-WTA final_est_ani + final_stats_df.at[idx, 'reassignment_status'] = 'lambda_failed' + continue + + # Build full_cov array (zeros for won k-mers not in sample + coverages for those in sample) + num_zeros = total_won_kmers - len(won_kmers_in_sample) + full_cov = [0] * num_zeros + won_kmers_in_sample + + # Recalculate lambda using ratio method + new_lambda = ratio_lambda(full_cov, MIN_COUNT_THRESH) + + if new_lambda is None: + # Lambda estimation failed - keep original ANI but mark status + final_stats_df.at[idx, 'reassignment_status'] = 'lambda_failed' + # Keep original ANI value (don't modify final_est_ani) + continue + + # Recalculate ANI from new lambda + mean_cov = sum(full_cov) / len(full_cov) if full_cov else 0 + new_ani = ani_from_lambda(new_lambda, mean_cov, ksize, full_cov) + + if new_ani is not None: + final_stats_df.at[idx, 'final_est_ani'] = new_ani + else: + # ANI calculation failed - mark status but keep original + final_stats_df.at[idx, 'reassignment_status'] = 'ani_failed' + + logger.info(f"ANI recalculation complete: {eliminated_count} organisms eliminated, " + f"{total_organisms - eliminated_count} remain active") + + return final_stats_df + + +def estimate_relative_abundance( + final_stats_df: pd.DataFrame, + winner_map: Dict[int, Tuple[float, str]], + sample_sig: sourmash.SourmashSignature, + batch_size: int = 1000 +) -> pd.DataFrame: + """ + Estimates the relative abundance of each organism based on winner_map k-mer assignments. + Uses memory-efficient batch processing. + + After winner_map assigns shared k-mers to organisms with highest ANI, this calculates: + 1. How many k-mers each organism "lost" to others (kmers_lost) + 2. Total coverage from k-mers "won" by each organism (used for relative abundance) + 3. Relative abundance normalized across all organisms + + Organisms marked as 'eliminated' in reassignment_status are skipped (rel_abund = 0). + + :param final_stats_df: DataFrame with coverage statistics + :param winner_map: k-mer to (ANI, organism_name) mapping from build_winner_map() + :param sample_sig: Sample signature with k-mer abundances + :param batch_size: Number of organisms to process per batch (default: 1000) + :return: Updated DataFrame with rel_abund and kmers_lost columns populated + """ + logger.info("Estimating relative abundance using winner map (memory-efficient mode)") + + # Initialize columns + final_stats_df['kmers_lost'] = 0 + final_stats_df['rel_abund'] = 0.0 + + sample_hashes = sample_sig.minhash.hashes + total_organisms = len(final_stats_df) + + # Check if reassignment_status column exists (from two-pass approach) + has_reassignment_status = 'reassignment_status' in final_stats_df.columns + + # Process in batches to reduce memory usage + for batch_start in range(0, total_organisms, batch_size): + batch_end = min(batch_start + batch_size, total_organisms) + batch_num = batch_start // batch_size + 1 + total_batches = (total_organisms + batch_size - 1) // batch_size + + for idx in tqdm( + range(batch_start, batch_end), + desc=f"Calculating relative abundance (batch {batch_num}/{total_batches})", + total=batch_end - batch_start + ): + row = final_stats_df.iloc[idx] + organism_name = row['organism_name'] + + # Skip eliminated organisms + if has_reassignment_status and row['reassignment_status'] == 'eliminated': + continue + + genome_sig = row['genome_sketch'] + + kmers_lost_count = 0 + total_coverage = 0.0 + + # Check each k-mer in this genome + for kmer in genome_sig.minhash.hashes.keys(): + # Check if this organism "won" this k-mer + if kmer in winner_map: + winner_organism = winner_map[kmer][1] + + if winner_organism != organism_name: + # This k-mer was reassigned to another organism + kmers_lost_count += 1 + else: + # This organism won this k-mer - count its coverage + if kmer in sample_hashes: + total_coverage += sample_hashes[kmer] + + final_stats_df.at[idx, 'kmers_lost'] = kmers_lost_count + + # Normalizes coverage by genome size to avoid bias toward larger genomes + # (genome size estimated as num_kmers * the scale factor from sourmash) + genome_size = len(genome_sig.minhash.hashes) * genome_sig.minhash.scaled + final_stats_df.at[idx, 'rel_abund'] = total_coverage / genome_size if genome_size > 0 else 0.0 + + # Normalize relative abundance to sum to 1.0 across all organisms + total_abundance = final_stats_df['rel_abund'].sum() + if total_abundance > 0: + final_stats_df['rel_abund'] = final_stats_df['rel_abund'] / total_abundance + logger.info(f"Relative abundance normalized (total coverage: {total_abundance:.2f})") + else: + logger.warning("No coverage found for relative abundance calculation") + + return final_stats_df def get_alt_mut_rate( @@ -291,9 +694,7 @@ def single_hyp_test( in_sample_est = (num_matches >= acceptance_threshold_with_coverage) and ( num_matches != 0 ) - # return in_sample_est, p_val, num_exclusive_kmers, num_exclusive_kmers_coverage, num_matches, \ - # acceptance_threshold_wo_coverage, acceptance_threshold_with_coverage, actual_confidence_wo_coverage, \ - # actual_confidence_with_coverage, alt_confidence_mut_rate, alt_confidence_mut_rate_with_coverage + return ( in_sample_est, p_val, @@ -316,6 +717,12 @@ def hypothesis_recovery( significance: float = 0.99, ani_thresh: float = 0.95, num_threads: int = 16, + winner_takes_all: bool = False, + batch_size: int = 1000, + two_pass: bool = True, + calculate_coverage: bool = False, + convergence_nr: bool = True, + min_ani: float = 0.95, ): """ Go through each of the organisms that have non-zero overlap with the sample and perform a hypothesis test for the @@ -339,6 +746,8 @@ def hypothesis_recovery( :param significance: significance level for the hypothesis test :param ani_thresh: threshold for ANI (i.e. how similar do the genomes need to be in order to be considered the same) :param num_threads: number of threads to use for parallelization + :param two_pass: bool (use sylph's two-pass approach for winner-takes-all, default: True) + :param calculate_coverage: bool (use calculated coverage per organism instead of min_coverage_list, default: False) :return: a list of pandas dataframe with the results of the hypothesis tests based on different min_coverage values """ @@ -369,11 +778,13 @@ def hypothesis_recovery( ) # Get the unique hashes exclusive to each of the organisms that have non-zero overlap with the sample - exclusive_hashes_info, manifest = get_exclusive_hashes( - manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir + exclusive_hashes_info, manifest, final_stats_df = get_exclusive_hashes( + manifest, nontrivial_organism_names, sample_sig, ksize, path_to_genome_temp_dir, + num_threads, winner_takes_all, batch_size, two_pass, convergence_nr, min_ani ) # Set up the results dataframe columns + # n.b. that the output of cov_calc is not being incorporated here; instead it's being returned separately, as a pandas dataframe. given_columns = [ "in_sample_est", # Main output: Boolean indicating whether genome is present in sample "p_vals", # Probability of observing this or more extreme result at ANI threshold. @@ -392,8 +803,59 @@ def hypothesis_recovery( # Using multiprocessing.Pool to parallelize the execution manifest_list = [] - for min_coverage in tqdm(min_coverage_list, desc="Computing hypothesis recovery"): - logger.info(f"Computing hypothesis recovery for min_coverage={min_coverage}") + + fallback_coverage = min(min_coverage_list) if min_coverage_list else 0.1 + + if calculate_coverage: + # CALCULATE_COVERAGE MODE: Use calculated coverage (final_est_cov) per organism + logger.info("Using calculate_coverage mode: applying calculated coverage per organism") + + # Build a mapping from organism_name to calculated coverage (final_est_cov) + # Coverage depth (lambda) is converted to detection fraction using Poisson probability: + # P(a k-mer is detected) = 1 - exp(-lambda) + # This gives the expected fraction of k-mers that will be observed at least once. + # Compute sample-wide median lambda for fallback coverage + valid_lambdas = [ + row['final_est_cov'] for _, row in final_stats_df.iterrows() + if pd.notna(row['final_est_cov']) and row['final_est_cov'] > 0 + ] + if valid_lambdas: + median_lambda = np.median(valid_lambdas) + fallback_coverage = 1.0 - np.exp(-median_lambda) + logger.info(f"Sample-wide median lambda: {median_lambda:.4f}, " + f"fallback detection fraction: {fallback_coverage:.4f}") + else: + fallback_coverage = 0.1 + logger.warning("No valid lambda estimates in sample; using fallback coverage 0.1") + + coverage_map = {} + for _, row in final_stats_df.iterrows(): + org_name = row['organism_name'] + cov_val = row['final_est_cov'] + median_cov = row['median_cov'] + + if pd.notna(cov_val) and cov_val > 0: + # Primary choicev: use final_est_cov (lambda) with Poisson detection probability + # Convert depth to the expected fraction of k-mers detected + coverage_map[org_name] = 1.0 - np.exp(-cov_val) + elif pd.notna(median_cov) and median_cov > 0: + # Fallback: use sample-wide median lambda rather than per-organism + # median_cov, since median_cov at low coverage (e.g. 1x) gives an + # artificially strict detection fraction of (1 - e^-1) + coverage_map[org_name] = fallback_coverage + logger.warning(f"No valid lambda for {org_name}, using fallback_coverage={fallback_coverage:.4f}") + else: + # The last resort + coverage_map[org_name] = fallback_coverage + logger.warning(f"No valid coverage data for {org_name}, using fallback_coverage={fallback_coverage:.4f}") + + # Get organism names in manifest order (aligned with exclusive_hashes_info) + organism_names = manifest["organism_name"].to_list() + + # Build per-organism coverage list aligned with exclusive_hashes_info + per_organism_coverage = [coverage_map.get(name, 1.0) for name in organism_names] + + # Run hypothesis test with per-organism coverage with Pool(processes=num_threads) as p: params = ( ( @@ -401,17 +863,89 @@ def hypothesis_recovery( ksize, significance, ani_thresh, - min_coverage, + per_organism_coverage[i], # Per-organism coverage ) for i in range(len(exclusive_hashes_info)) ) results = p.starmap(single_hyp_test, params) + logger.info("Finished computing hypothesis recovery with calculate_coverage") - # Create a pandas dataframe to store the results + # Create results DataFrame results = pd.DataFrame(results, columns=given_columns) - # combine the results with the manifest - manifest["min_coverage"] = min_coverage + # Add per-organism coverage to manifest (replaces the fixed min_coverage column) + manifest["min_coverage"] = per_organism_coverage manifest_list.append(pd.concat([manifest, results], axis=1)) + else: + # ORIGINAL MODE: Loop over user-supplied min_coverage_list + for min_coverage in tqdm(min_coverage_list, desc="Computing hypothesis recovery"): + logger.info(f"Computing hypothesis recovery for min_coverage={min_coverage}") + with Pool(processes=num_threads) as p: + params = ( + ( + exclusive_hashes_info[i], + ksize, + significance, + ani_thresh, + min_coverage, + ) + for i in range(len(exclusive_hashes_info)) + ) + results = p.starmap(single_hyp_test, params) + logger.info(f"Finished computing all results for min_coverage value: {min_coverage}") + + # Create a pandas dataframe to store the results + results = pd.DataFrame(results, columns=given_columns) + + # combine the results with the manifest + manifest["min_coverage"] = min_coverage + manifest_list.append(pd.concat([manifest, results], axis=1)) + + # Merge coverage statistics into each manifest DataFrame + # Select key coverage columns to include in output (including winner_map results) + coverage_cols = [ + 'organism_name', + 'naive_ani', + 'final_est_ani', + 'final_est_cov', + 'mean_cov', + 'median_cov', + 'lambda_status', + 'ani_ci', + 'lambda_ci', + 'rel_abund', # Relative abundance from winner_map + 'kmers_lost' # K-mers reassigned to other organisms + ] + coverage_stats = final_stats_df[coverage_cols].copy() + + # Merge coverage stats into each manifest in the list + for i in range(len(manifest_list)): + manifest_list[i] = manifest_list[i].merge( + coverage_stats, + on='organism_name', + how='left' # Keep all organisms, even those without coverage stats + ) + + # ANI threshold filtering + logger.info(f"Filtering organisms with final_est_ani < {min_ani} ({min_ani*100:.0f}% ANI)") + for i in range(len(manifest_list)): + initial_count = len(manifest_list[i]) + # Keep organisms with ANI >= threshold OR organisms with no ANI estimate (NaN) + manifest_list[i] = manifest_list[i][ + (manifest_list[i]['final_est_ani'] >= min_ani) | + (manifest_list[i]['final_est_ani'].isna()) + ].reset_index(drop=True) + filtered_count = initial_count - len(manifest_list[i]) + if filtered_count > 0: + logger.info(f" Filtered {filtered_count} organisms below ANI threshold from min_coverage={manifest_list[i]['min_coverage'].iloc[0] if len(manifest_list[i]) > 0 else 'N/A'} results") + #post_filtered_df['rel_abund'] = post_filtered_df['rel_abund'] / total_abundance + # Re-normalizing, regardless of filter results (i.e. filtered_count) + post_filtered_df = manifest_list[i] + total_abundance = post_filtered_df['rel_abund'].sum() + if total_abundance > 0: + manifest_list[i].loc[:, 'rel_abund'] = manifest_list[i]['rel_abund'] / total_abundance + logger.info(f"Relative abundance normalized (total coverage: {total_abundance:.2f}.)") + else: + logger.warning(f"No relative abundance was done after ANI filtering.") return manifest_list diff --git a/src/yacht/run_YACHT.py b/src/yacht/run_YACHT.py index dc5d5b02..0bd9bcc1 100644 --- a/src/yacht/run_YACHT.py +++ b/src/yacht/run_YACHT.py @@ -45,6 +45,29 @@ def add_arguments(parser): required=False, default=16, ) + parser.add_argument( + "--winner_takes_all", + action="store_true", + help="Enables winner-takes-all k-mer reassignment and relative abundance estimation. " + "Shared k-mers are assigned to the taxon with the highest ANI. " + "Uses more memory-efficient batch processing. ", + default=False, + ) + parser.add_argument( + "--batch_size", + type=int, + help="Batch size for winner-takes-all processing (lower size uses less memory). " + "Only used with --winner_takes_all. Default: 1000", + default=1000, + ) + parser.add_argument( + "--no_two_pass", + action="store_true", + help="Disable sylph's two-pass winner-takes-all approach. Uses original one-pass method. " + "Two-pass (default) is more accurate for closely related organisms but slower. " + "Only applies when --winner_takes_all is enabled.", + default=False, + ) parser.add_argument( "--keep_raw", action="store_true", help="Keep raw results in output file." ) @@ -59,9 +82,38 @@ def add_arguments(parser): type=float, help="A list of percentages of unique k-mers covered by reads in the sample. " "Each value should be between 0 and 1, with 0 being the most sensitive (and least " - "precise) and 1 being the most precise (and least sensitive).", + "precise) and 1 being the most precise (and least sensitive). " + "Default: [1, 0.5, 0.1, 0.05, 0.01]", + required=False, + default=None, + ) + parser.add_argument( + "--calculate_coverage", + action="store_true", + help="Automatically calculate coverage for each organism using the sylph coverage model " + "(lambda estimation) instead of using user-supplied min_coverage_list values. " + "When enabled, produces a single output sheet with per-organism calculated coverage. " + "Cannot be used together with --min_coverage_list.", + default=False, + ) + parser.add_argument( + "--convergence_nr", + action="store_true", + help="Turn on the convergence criterion in the Newton-Raphson lambda estimator, " + "terminating when the update falls below " + f"LAMBDA_EPSILON ({utils.LAMBDA_EPSILON}). " + "By default, all 1000 iterations run without terminating, " + "matching the original sylph behavior. ", + default=False, + ) + parser.add_argument( + "--min_ani", + type=float, + help="Minimum ANI threshold for retaining organisms in results. " + "Organisms whose final estimated ANI falls below this value are filtered out. " + "Default: 0.95 (species-level boundary).", required=False, - default=[1, 0.5, 0.1, 0.05, 0.01], + default=0.95, ) parser.add_argument( "--out", @@ -77,13 +129,44 @@ def main(args): sample_file = str(Path(args.sample_file).absolute()) # location of sample.sig file significance = args.significance # Minimum probability of individual true negative. num_threads = args.num_threads # Number of threads to use for parallelization. + winner_takes_all = args.winner_takes_all # Enable winner-takes-all k-mer reassignment + batch_size = args.batch_size # Batch size for winner-takes-all processing + two_pass = not args.no_two_pass # Use sylph's two-pass approach (default: True) keep_raw = args.keep_raw # Keep raw results in output file. show_all = args.show_all # Show all organisms (no matter if present) in output file. - min_coverage_list = args.min_coverage_list # a list of percentages of unique k-mers covered by reads in the sample. + calculate_coverage = args.calculate_coverage # Use calculated coverage instead of user-supplied list + convergence_nr = args.convergence_nr # Use convergence criterion in Newton-Raphson (default: False) + min_ani = args.min_ani # Minimum ANI threshold for filtering organisms + + if not (0.90 <= min_ani <= 1): + raise ValueError( + f"--min_ani value {min_ani} must be between 0.90 (genus-level) and 1 (both inclusive)." + ) out = str(Path(args.out).absolute()) # full path to output excel file + + # Validate mutual exclusivity of --calculate_coverage and --min_coverage_list + if calculate_coverage and args.min_coverage_list is not None: + raise ValueError( + "--calculate_coverage and --min_coverage_list cannot be used together. " + "Use --calculate_coverage for automatic per-organism coverage calculation, " + "or --min_coverage_list for manual coverage thresholds." + ) + + # Set default min_coverage_list if not provided and not using calculate_coverage + if args.min_coverage_list is None: + min_coverage_list = [1, 0.5, 0.1, 0.05, 0.01] # Default values + else: + min_coverage_list = args.min_coverage_list outdir = os.path.dirname(out) # path to output directory out_filename = os.path.basename(out) # output filename + # Validate that batch_size is only used with winner_takes_all + if batch_size != 1000 and not winner_takes_all: + raise ValueError( + "--batch_size can only be used with --winner_takes_all. " + "Either remove --batch_size or add --winner_takes_all." + ) + # check if the output filename is valid if os.path.splitext(out_filename)[1] != ".xlsx": raise ValueError( @@ -192,6 +275,12 @@ def main(args): significance, ani_thresh, num_threads, + winner_takes_all, + batch_size, + two_pass, + calculate_coverage, + convergence_nr, + min_ani, ) # remove unnecessary columns @@ -211,29 +300,45 @@ def main(args): logger.info(f"Saving results to {outdir}.") # save the results with different min_coverage with pd.ExcelWriter(out, engine="openpyxl", mode="w") as writer: - # save the raw results (i.e., min_coverage=1.0) - if keep_raw: - temp_mainifest = manifest_list[0].copy() - temp_mainifest.rename( - columns={ - "acceptance_threshold_with_coverage": "acceptance_threshold_wo_coverage", - "actual_confidence_with_coverage": "actual_confidence_wo_coverage", - "alt_confidence_mut_rate_with_coverage": "alt_confidence_mut_rate_wo_coverage", - }, - inplace=True, - ) - temp_mainifest.to_excel(writer, sheet_name="raw_result", index=False) - # save the results with different min_coverage given by the user - if not has_raw: - min_coverage_list = min_coverage_list[1:] - manifest_list = manifest_list[1:] - - for min_coverage, temp_mainifest in zip(min_coverage_list, manifest_list): + if calculate_coverage: + # Calculate coverage mode: single sheet with per-organism calculated coverage + temp_manifest = manifest_list[0].copy() if not show_all: - temp_mainifest = temp_mainifest[temp_mainifest["in_sample_est"] == True] - temp_mainifest.to_excel( - writer, sheet_name=f"min_coverage{min_coverage}", index=False - ) + temp_manifest = temp_manifest[temp_manifest["in_sample_est"] == True] + # Adding re-normalization here also + total_abundance = temp_manifest['rel_abund'].sum() + if total_abundance > 0: + temp_manifest.loc[:, 'rel_abund'] = temp_manifest['rel_abund'] / total_abundance + temp_manifest.to_excel(writer, sheet_name="calculated_coverage", index=False) + else: + # Original behavior: multiple sheets based on min_coverage_list + # save the raw results (i.e., min_coverage=1.0) + if keep_raw: + temp_manifest = manifest_list[0].copy() + temp_manifest.rename( + columns={ + "acceptance_threshold_with_coverage": "acceptance_threshold_wo_coverage", + "actual_confidence_with_coverage": "actual_confidence_wo_coverage", + "alt_confidence_mut_rate_with_coverage": "alt_confidence_mut_rate_wo_coverage", + }, + inplace=True, + ) + temp_manifest.to_excel(writer, sheet_name="raw_result", index=False) + # save the results with different min_coverage given by the user + if not has_raw: + min_coverage_list = min_coverage_list[1:] + manifest_list = manifest_list[1:] + + for min_coverage, temp_manifest in zip(min_coverage_list, manifest_list): + if not show_all: + temp_manifest = temp_manifest[temp_manifest["in_sample_est"] == True] + #adding renormilization for the original behavior + total_abundance = temp_manifest['rel_abund'].sum() + if total_abundance > 0: + temp_manifest.loc[:, 'rel_abund'] = temp_manifest['rel_abund'] / total_abundance + temp_manifest.to_excel( + writer, sheet_name=f"min_coverage{min_coverage}", index=False + ) if __name__ == "__main__": diff --git a/src/yacht/standardize_yacht_output.py b/src/yacht/standardize_yacht_output.py index 8fdac045..26661af6 100644 --- a/src/yacht/standardize_yacht_output.py +++ b/src/yacht/standardize_yacht_output.py @@ -32,8 +32,21 @@ def add_arguments(parser): parser.add_argument( "--sheet_name", type=str, - help="The sheet name of the YACHT output excel file.", - required=True, + help="The sheet name of the YACHT output excel file. " + "Required unless --single_sheet is used.", + required=False, + default=None, + ) + parser.add_argument( + "--single_sheet", + action="store_true", + default=False, + help="Automatically selects the 'calculated_coverage' sheet produced when " + "yacht run is invoked with --calculate_coverage. For full relative " + "abundance percentages, yacht run must also have been called with " + "--winner_takes_all; otherwise percentages revert to count-based. " + "Mutually exclusive with --sheet_name." + ) parser.add_argument( "--genome_to_taxid", @@ -70,13 +83,24 @@ def add_arguments(parser): def main(args): yacht_output = args.yacht_output - sheet_name = args.sheet_name genome_to_taxid = args.genome_to_taxid mode = args.mode sample_name = args.sample_name outfile_prefix = args.outfile_prefix outdir = args.outdir + # resolves sheet name from --sheet_name or --single_sheet + if args.single_sheet and args.sheet_name: + logger.error("--single_sheet and --sheet_name are mutually exclusive.") + raise ValueError + elif args.single_sheet: + sheet_name = "calculated_coverage" + elif args.sheet_name: + sheet_name = args.sheet_name + else: + logger.error("One of either --sheet_name or --single_sheet must be provided.") + raise ValueError + # check if the yacht output file exists if not os.path.exists(yacht_output): logger.error(f"{yacht_output} does not exist.") @@ -92,9 +116,20 @@ def main(args): os.makedirs(outdir) # load the yacht output - yacht_output_df = pd.read_excel( - yacht_output, sheet_name=sheet_name, engine="openpyxl" - ) + try: + yacht_output_df = pd.read_excel( + yacht_output, sheet_name=sheet_name, engine="openpyxl" + ) + except ValueError as e: + if args.single_sheet: + logger.error( + f"Sheet 'calculated_coverage' not found in {yacht_output}. " + "This sheet is only present when YACHT was run with --calculate_coverage. " + "If you used a min_coverage list instead, use --sheet_name to specify the sheet." + ) + else: + logger.error(f"Sheet '{sheet_name}' not found in {yacht_output}: {e}") + sys.exit(1) # converet the first column to string yacht_output_df["organism_name"] = yacht_output_df["organism_name"].astype(str) @@ -247,7 +282,7 @@ def __to_cami(self, sample_name): ## select the organisms that YACHT considers to present in the sample yacht_res_df = self.yacht_output.copy() - organism_id_list = yacht_res_df["organism_name"].tolist() + organism_id_list = yacht_res_df["organism_name"].str.split().str[0].tolist() if len(organism_id_list) == 0: logger.error("No organism is detected by YACHT.") @@ -258,6 +293,40 @@ def __to_cami(self, sample_name): "genome_id in @organism_id_list" ).reset_index(drop=True) + ## Determines per-organism weights for percentage calculation. + use_rel_abund = ( + "rel_abund" in self.yacht_output.columns + and self.yacht_output["rel_abund"].notna().any() + ) + if "rel_abund" in self.yacht_output.columns and not use_rel_abund: + logger.warning( + "rel_abund column is present but empty — YACHT run was likely executed without " + "--winner_takes_all." + ) + if use_rel_abund: + genome_id_set = set(selected_organism_metadata_df["genome_id"]) + org_weights = ( + self.yacht_output + .assign(_key=self.yacht_output["organism_name"].str.split().str[0]) + .set_index("_key")["rel_abund"] + .reindex(sorted(genome_id_set), fill_value=0.0) + .fillna(0.0) + ) + total_weight = org_weights.sum() + if total_weight == 0: + logger.warning( + "Sum of rel_abund is zero; reverting to count-based percentages." + ) + use_rel_abund = False + else: + weight_lookup = org_weights.to_dict() + elif not use_rel_abund: + logger.warning( + "Falling back to count-based percentages." + ) + total_weight = float(len(selected_organism_metadata_df)) + weight_lookup = None + ## Summarize the results summary_dict = {} for row in selected_organism_metadata_df.to_numpy(): @@ -269,6 +338,7 @@ def __to_cami(self, sample_name): taxid_list = list(np.array(row[3].split("|"))[select_index]) lineage_list = list(np.array(row[4].split("|"))[select_index]) rank_list = list(np.array(row[5].split("|"))[select_index]) + weight = weight_lookup.get(row[0], 0.0) if use_rel_abund else 1.0 current_lineage = "" current_taxid = "" for index, (taxid, rank, lineage) in enumerate( @@ -285,15 +355,15 @@ def __to_cami(self, sample_name): "RANK": rank, "TAXPATH": current_taxid, "TAXPATHSN": current_lineage, - "count": 1, + "weight": weight, } else: - summary_dict[taxid]["count"] += 1 + summary_dict[taxid]["weight"] += weight # calculate percentage for taxid in summary_dict: summary_dict[taxid]["PERCENTAGE"] = ( - summary_dict[taxid]["count"] / len(selected_organism_metadata_df) * 100 + summary_dict[taxid]["weight"] / total_weight * 100 ) ## sort by rank in allowable rank list @@ -303,8 +373,9 @@ def __to_cami(self, sample_name): .rename(columns={"index": "TAXID"}) ) res_df = [summary_df.query(f'RANK == "{rank}"') for rank in self.allowable_rank] - res_df = pd.concat(res_df).drop(columns=["count"]).reset_index(drop=True) + res_df = pd.concat(res_df).drop(columns=["weight"]).reset_index(drop=True) res_df.columns = ["@@TAXID", "RANK", "TAXPATH", "TAXPATHSN", "PERCENTAGE"] + res_df["PERCENTAGE"] = res_df["PERCENTAGE"].round(6) ## output summary results if len(res_df) == 0: diff --git a/src/yacht/utils.py b/src/yacht/utils.py old mode 100644 new mode 100755 index beaf0988..325a27e8 --- a/src/yacht/utils.py +++ b/src/yacht/utils.py @@ -1,15 +1,23 @@ import os import sys import numpy as np -import sourmash +from typing import List, Optional +from collections import Counter +from sourmash import load_one_signature +from collections import defaultdict from tqdm import tqdm import pandas as pd from multiprocessing import Pool from loguru import logger -from typing import Optional, List, Set, Dict, Tuple +from typing import Optional, List, Set, Dict, Tuple, Any import shutil import gzip +import math +import random +import sourmash +from dataclasses import dataclass from glob import glob +from scipy.special import gamma # Configure Loguru logger logger.remove() @@ -17,16 +25,91 @@ sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} - {level} - {message}", level="INFO" ) -# Set up contants +# Set up constants COL_NOT_FOUND_ERROR = "Column not found: {}" FILE_LOCATION = os.path.dirname(os.path.realpath(__file__)) +# Sylph (Shaw and Yu, 2024) related constants +SAMPLE_SIZE_CUTOFF: int = 25 +PVALUE_CUTOFF: float = 0.9999999999 +MEDIAN_ANI_THRESHOLD: float = 3.00 +MAX_MEDIAN_FOR_MEAN_FINAL_EST: float = 15.0 +MIN_COUNT_THRESH: int = 3 +LAMBDA_EPSILON: float = 1e-10 # Minimum lambda value to avoid dividing by zero +ksize: int = 31 # Note: hard-coding this for now # Set up global variables -__version__ = "1.3.2" +__version__ = "2.2.0" GITHUB_API_URL = "https://api.github.com/repos/KoslickiLab/YACHT/contents/demo/{path}" GITHUB_RAW_URL = "https://raw.githubusercontent.com/KoslickiLab/YACHT/main/demo/{path}" BASE_URL = "https://farm.cse.ucdavis.edu/~ctbrown/sourmash-db/" -ZENODO_COMMUNITY_URL = "https://zenodo.org/api/records?q=communities:yacht&size=25" +ZENODO_COMMUNITY_URL = "https://zenodo.org/api/records/?communities=yacht&size=100" + +# Pythonic enum implementation for lambda adjustment status +from enum import Enum + +class AdjustStatusType(Enum): + """Status types for lambda adjustment.""" + LAMBDA = "lambda" + HIGH = "high" + LOW = "low" + NONE = "none" + +@dataclass(frozen=True) +class AdjustStatus: + """Lambda adjustment status with optional value.""" + status: AdjustStatusType + value: Optional[float] = None + + @classmethod + def lambda_value(cls, value: float): + """Create a Lambda status with a value.""" + return cls(AdjustStatusType.LAMBDA, value) + + @classmethod + def high(cls): + """Create a High status.""" + return cls(AdjustStatusType.HIGH) + + @classmethod + def low(cls): + """Create a Low status.""" + return cls(AdjustStatusType.LOW) + + @classmethod + def none(cls): + """Create a None status.""" + return cls(AdjustStatusType.NONE) + +# Class for cov_calc output +@dataclass +class AniResult: + naive_ani: float + final_est_ani: float + final_est_cov: float + seq_name: str + gn_name: str + contig_name: str + mean_cov: float + median_cov: float + containment_index: Tuple[int, int] + lambda_status: AdjustStatus + ani_ci: Tuple[Optional[float], Optional[float]] + lambda_ci: Tuple[Optional[float], Optional[float]] + genome_sketch: Any + rel_abund: Optional[float] + seq_abund: Optional[float] + kmers_lost: Optional[int] + +# Class for cov_calc arguments +# Note that these are being set here for now, but could be brought as CLI arguments for yacht +class _ContainArgs: + def __init__(self): + self.ci_int = True + self.ratio = True + self.mme = True + self.nb = True + self.mle = True + self.min_count_correct = 1 # Example value def load_signature_with_ksize(filename: str, ksize: int) -> sourmash.SourmashSignature: """ @@ -155,9 +238,10 @@ def run_yacht_train_core( selected_sig_files = pd.read_csv(os.path.join(path_to_temp_dir, 'selected_result.tsv'), sep="\t", header=None) selected_sig_files = selected_sig_files[0].to_list() - # get the mapping from signature file name to genome name - mapping = {sig_info_dict[name][-1]:name for name in sig_info_dict} - selected_genome_names_set = set([mapping[sig_file_path] for sig_file_path in selected_sig_files]) + # get the mapping from signature file name to genome name; normalize to basename for matching. Basename extracted from C++ + mapping = {os.path.basename(sig_info_dict[name][-1]): name for name in sig_info_dict} + selected_genome_names_set = set([mapping[os.path.basename(sig_file_path)] for sig_file_path in selected_sig_files]) + # remove the close related organisms from the reference genome list manifest_df = [] @@ -218,7 +302,13 @@ def collect_signature_info( ], ) - return {sig[1]: (sig[2], sig[3], sig[4], sig[5], sig[0]) for sig in tqdm(signatures) if sig} + def get_key_with_warning(sig): + if not sig[1]: + logger.warning(f"Signature has no name, using md5sum as identifier: {sig[2]}") + return sig[2] + return sig[1] + + return {get_key_with_warning(sig): (sig[2], sig[3], sig[4], sig[5], sig[0]) for sig in tqdm(signatures) if sig} class Prediction: @@ -478,23 +568,28 @@ def check_download_args(args, db_type): logger.error("We now haven't supported for virus database.") sys.exit(1) - def _decompress_and_remove(file_path: str) -> None: """ Decompresses a GZIP-compressed file and removes the original compressed file. :param file_path: The path to the .sig.gz file that needs to be decompressed and deleted. :return: None """ + import subprocess try: output_filename = os.path.splitext(file_path)[0] - with gzip.open(file_path, 'rb') as f_in: - with open(output_filename, 'wb') as f_out: - f_out.write(f_in.read()) - - os.remove(file_path) - + with open(output_filename, 'wb') as f_out: + result = subprocess.run( + ['gunzip', '-c', file_path], + stdout=f_out, + stderr=subprocess.PIPE + ) + if result.returncode == 0: + os.remove(file_path) + else: + logger.info(f"gunzip failed for {file_path}: {result.stderr.decode()}") except Exception as e: logger.info(f"Failed to process {file_path}: {e}") + def decompress_all_sig_files(sig_files: List[str], num_threads: int) -> None: """ @@ -507,3 +602,310 @@ def decompress_all_sig_files(sig_files: List[str], num_threads: int) -> None: p.map(_decompress_and_remove, sig_files) logger.info("All .sig.gz files have been decompressed.") + +def load_one_sig(sig_path: str, ksize: int): + """ + Imports a sourmash signature file using the path and the ksize. + """ + loaded_sig = load_one_signature(sig_path, ksize, select_moltype='DNA').to_mutable() + + logger.info("The sig file has been loaded." + ) + return(loaded_sig) + +def newton_raphson(ratio: float, mean: float, convergence: bool = True): + """ + Shaw and Yu (2024)'s implmentation of Newton-Raphson use to assist in the calculation of lambda. + """ + ratio = min(ratio, 1.0 - LAMBDA_EPSILON) + curr = mean / (1 - ratio) + + for _ in range(1000): + t1 = (1 - ratio) * curr + e_curr = math.exp(-curr) + t2 = mean * (1 - e_curr) + t3 = 1 - ratio + t4 = mean * e_curr + denom = t3 - t4 + if abs(denom) < LAMBDA_EPSILON: + break + prev = curr + curr = curr - (t1 - t2) / denom + if not math.isfinite(curr): + return None + if convergence and abs(curr - prev) < LAMBDA_EPSILON: + break + return curr + +def mle_zip(full_covs: list[int], _k: float, convergence: bool = True): + """ + Maximum likelihood estimator for the zero-inflated Poisson (ZIP) distribution from Shaw and Yu (2024) + """ + n_zero = 0 + count_set = Counter() #creating a new counter set + + for x in full_covs: + if x == 0: + n_zero += 1 + else: + count_set[x] +=1 + + if len(count_set) == 1: #If no info for inference, retuns None + return None + + if len(full_covs) - n_zero < SAMPLE_SIZE_CUTOFF: + return None + + mean = np.mean(full_covs) + nr_input = n_zero / len(full_covs) + lambda_out = newton_raphson(nr_input, mean, convergence) + + if lambda_out is None or lambda_out < 0 or not math.isfinite(lambda_out): + lambda_ret = None + else: + lambda_ret = lambda_out + return lambda_ret + +def variance(data: List[int]): + """ + An internal function that calculates the variance, which is the average of the squared differences of all values from the mean + """ + if len(data) == 0: + return None + mean = np.mean(data) + var = 0 + for x in data: + var += (float(x) - mean) * (float(x) - mean) + var_out = float(var / len(data)) + return var_out + +def ratio_lambda(full_covs: list[int], min_count_correct): + """ + Estimating lambda according to Shaw and Yu (2024) using the ratio method + """ + n_zero = 0 + count_map = defaultdict(int) # Creates an empty dictionary for integer values equivalent to FxHashMap::default() + for x in full_covs: + if x == 0: + n_zero += 1 + else: + count_map[x] +=1 + if len(count_map) == 1: # Absent info for inference, returns None. + return None + + if (len(full_covs) - n_zero) < SAMPLE_SIZE_CUTOFF: + return None + else: + sort_vec = [(value, key) for key, value in count_map.items()] + sort_vec.sort(reverse=True) #sorting the vector in reverse order + most_ind = sort_vec[0][1] + if (most_ind + 1) not in count_map: + return None + + count_p1 = count_map[(most_ind + 1)] + count = count_map[most_ind] + + if count_p1 < min_count_correct or count < min_count_correct: + return None + + lambda_out = count_p1 / (count * (most_ind + 1)) + return lambda_out + +def r_moments_lambda(m: float, v: float, lambda_out: float): + """ + Internal function used used in the calculation of lambda using ratio from moments + """ + result = lambda_out / (v - 1 + lambda_out + m) + return result + +def ratio_calc(val: float, r: float, lambda_out: float): + """ + Internal function used used in the calculation of lambda using ratio from moments + """ + if (r < 100): + return gamma(r + val + 1) / (val + 1) / gamma(r + val) * lambda_out / (r + lambda_out) + else: + return (r + val + 1) / (val + 1) * lambda_out / (r + lambda_out) + +def ratio_from_moments_lambda(val: float, lambda_out: float, m: float, v: float): + """ + Function that calculates lambda using the ratio from moments formula + """ + r = r_moments_lambda(m, v, lambda_out) + if r < 0: + return None + rat_return = ratio_calc(val, r, lambda_out) + return rat_return + +def mme_lambda(full_covs: list[int]) -> Optional[float]: + """ + Calculates the "method of moments" estimator for the lambda parameter + from a list of coverage values. + """ + num_zero = 0 + count_set = set() + + for x in full_covs: + if x == 0: + num_zero += 1 + else: + count_set.add(x) + if len(count_set) == 1: + return None + # Does the number of non-zero observations meet the cutoff threshold? + if len(full_covs) - num_zero < SAMPLE_SIZE_CUTOFF: + return None + # Calculating mean and variance + try: + mean_val = np.mean(full_covs) + variance_val = variance(full_covs) + except ValueError: + return None + # Calculating lambda using the MME formula + lambda_val = variance_val / mean_val + mean_val - 1.0 + + # Ensuring lambda is non-negative + if lambda_val < 0.0: + return None + else: + return lambda_val + +def binary_search_lambda(full_covs: list[int]): + if len(full_covs) == 0: + return None + m = np.mean(full_covs) + v = variance(full_covs) + nonzero = 0 + ones = 0 + twos = 0 + + for x in full_covs: + if x != 0: + nonzero += 1 + if x == 1: + ones += 1 + elif x == 2: + twos += 1 + + if ones == 0: + return None + ratio_est = float(twos) / float(ones) + + left = float(max(0.003, m - 2)) + right = m + 5 + best = None + best_val = 10000 + for i in range(10000): + test = (right - left)/10000 * float(i) + left + proposed = ratio_from_moments_lambda(1, test, m, v); + if proposed is not None: + p = proposed - ratio_est + if abs(p) < best_val: + best_val = abs(p) + best = test + + if best == None: + return None + #consider putting in a RaiseError statement here + r = r_moments_lambda(m, v, best) + # removed debugging code from sylph (see inference.rs for reference) + return best + +def bootstrap_interval(covs_full: list[int], k: float, args: _ContainArgs): + """ + This function calculates bootstrap confidence intervals for ANI and lambda. + """ + if args.ci_int == False: + return (None, None, None, None) + + num_samp = len(covs_full) + iters = 100 + res_ani = [] + res_lambda = [] + + for _ in range(iters): + rand_vec = [] + for _ in range(num_samp): + rand_vec.append(random.choice(covs_full)) + if args.ratio: + lambda_val = ratio_lambda(rand_vec, args.min_count_correct) + elif args.mme: + lambda_val = mme_lambda(rand_vec) + elif args.nb: + lambda_val = binary_search_lambda(rand_vec) + elif args.mle: + lambda_val = mle_zip(rand_vec, ksize) + else: + lambda_val = ratio_lambda(rand_vec, args.min_count_correct) + + ani_val = ani_from_lambda(lambda_val, np.mean(rand_vec), ksize, rand_vec) + + if ani_val is not None and lambda_val is not None: + if not np.isnan(ani_val) and not np.isnan(lambda_val): + res_ani.append(ani_val) + res_lambda.append(lambda_val) + + res_ani.sort() + res_lambda.sort() + + if len(res_ani) < 50: + return (None, None, None, None) + + suc = len(res_ani) + low_ani = res_ani[suc * 5 // 100] + high_ani = res_ani[suc * 95 // 100] + low_lambda = res_lambda[suc * 5 // 100] + high_lambda = res_lambda[suc * 95 // 100] + + return (low_ani, high_ani, low_lambda, high_lambda) + +def ani_from_lambda(lambda_val, lam_mean, k_value, full_cov): + """ + Calculates an adjusted ani value + Args: + lambda_val: An optional float used to calculate ani + lam_mean: A float value (unused in the original logic). + k: A float value used as the inverse exponent. + full_cov: A list of integers to analyze for non-zero counts. + + Returns: + An optional float representing the calculated adjusted index 'ani', + or None if the input lambda is None, or if 'ani' is negative or NaN. + """ + if lambda_val == None: + return None + + # Check if lambda is too close to zero so that we avoid dividing by zero + if abs(lambda_val) < LAMBDA_EPSILON: + return None + + contain_count = 0 + zero_count = 0 + for x in full_cov: + if x != 0: + contain_count += 1 + + else: + zero_count += 1 + + if not full_cov: + return None + + adj_index = contain_count / (1.0 - math.exp(-lambda_val)) / len(full_cov) + + # Calculating ani using math.pow for clarity (and to maintain type correctness) + ani = math.pow(adj_index, 1.0 / k_value) + + # Caps ANI at 1.0 (100% identity, so impossible to exceed this) + # This can happen in instances with very low lambda values where the adjustment inflates ANI + if ani > 1.0: + logger.debug(f"ANI {ani:.6f} exceeds 1.0 (lambda={lambda_val:.4f}, adj_index={adj_index:.4f}), capping at 1.0") + ani = 1.0 + + if ani < 0.0 or math.isnan(ani): + ret_ani = None + + else: + ret_ani = ani + + return ret_ani diff --git a/tests/internal_superyacht_test.py b/tests/internal_superyacht_test.py new file mode 100755 index 00000000..d3f66664 --- /dev/null +++ b/tests/internal_superyacht_test.py @@ -0,0 +1,71 @@ +import pandas as pd +import zipfile +import glob +import sys +import os +import yacht +import sourmash +from multiprocessing import Pool +import multiprocessing +from yacht.hypothesis_recovery_src import get_exclusive_hashes +from yacht.hypothesis_recovery_src import get_organisms_with_nonzero_overlap +from yacht.hypothesis_recovery_src import hypothesis_recovery +from yacht.utils import decompress_all_sig_files +from yacht.utils import load_signature_with_ksize +from yacht.utils import load_one_sig +""" +A script that reproduces the results of hypothesis_recovery_src so that the sylph coverage model can be incorporated +""" +manifestFilePath="/Users/randolph.raborn/Desktop/YACHT/demo/query_data/gtdb-rs214-reps.k31_0.95_pretrained/gtdb-rs214-reps.k31_0.95_processed_manifest.tsv" +sampleSigPath="/Users/randolph.raborn/Desktop/YACHT/demo/new_demo_files/SRR6940089_sample_out.sig.zip" +genTempDir="/Users/randolph.raborn/Desktop/YACHT/demo/query_data/gtdb_ani_thresh_0.95_intermediate_files/" +sampleTempDir="/Users/randolph.raborn/Desktop/YACHT/demo/new_demo_files/genome_reference.sig_temp" +multisearch_result_file = "/Users/randolph.raborn/Desktop/YACHT/demo/query_data/gtdb-rs214-reps.k31_0.95_pretrained/gtdb-rs214-reps.k31_0.95_intermediate_files/training_multisearch_result.csv" +multisearchResultPath= "/Users/randolph.raborn/Desktop/YACHT/demo/new_demo_files/genome_reference.sig_temp/sample_multisearch_result.csv" +ksize=31 +min_cov=0.10 #setting to 0.10 + +sample_sig = load_signature_with_ksize(sampleSigPath, ksize) +print(sample_sig) +sample_info_set = (sampleSigPath, sample_sig) + +#importing the manifest: +manifest = pd.read_csv(manifestFilePath, sep="\t", header=0) +#print(manifest.head()) + +#following the procedure at the end of yacht's `get_organisms_with_nonzero_overlap` +#multisearch_result = pd.read_csv( +# multisearch_result_file, +# sep=",", +# header=0, +# ) + +#print(multisearch_result) +#multisearch_result_new = multisearch_result.drop_duplicates().reset_index(drop=True) +#print(multisearch_result_new) +#multisearch_result_names = multisearch_result["match_name"].to_list() #this is what is actually returned by get_organisms_with_nonzero_overlap +#print(set(multisearch_result_names)) + +multiprocessing.set_start_method('fork') +#multisearch_result_new2 = get_organisms_with_nonzero_overlap(manifest, sampleSigPath, 1000, 31, 2, genTempDir, sampleTempDir) #comment this out to import directly from the multisearch result file +#print(multisearch_result_new2) +#multisearch_result_new2_file = pd.read_csv(multisearchResultPath, sep=",", header=0) +#multisearch_result_new2_file2 = multisearch_result_new2_file.drop_duplicates().reset_index(drop=True) + +#print(f"Manifest") +#print(manifest.head()) +#print(f"Multisearch_object") +#print(multisearch_result_new2_file2.head()) + +#multisearch_result_names2 = multisearch_result_new2_file["match_name"].to_list() +print(f"Made it here") +#print(multisearch_result_names2) +#exclusive_hashes_out = get_exclusive_hashes(manifest, multisearch_result_names2, sample_sig, ksize, genTempDir) +#print(exclusive_hashes_out) + +#creating a minimum coverage list +min_cov_list = [min_cov] * len(manifest) + +hyp_rec_out = hypothesis_recovery(manifest=manifest, sample_info_set=sample_info_set, path_to_genome_temp_dir=genTempDir, min_coverage_list=min_cov_list, scale=1000, ksize=ksize, ani_thresh=0.85, num_threads=16) + +print(hyp_rec_out) \ No newline at end of file