Skip to content

Commit 6ffee38

Browse files
authored
add nested classification
1 parent 6c3e29f commit 6ffee38

1 file changed

Lines changed: 133 additions & 36 deletions

File tree

src/classify.rs

Lines changed: 133 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
//! This module handles the `classify` subcommand.
2+
//! MODIFIED: Includes strictly corrected logic for nested lineage classification.
23
34
use crate::errors::{AppError, AppResult};
45
use bio::io::fasta;
@@ -10,7 +11,7 @@ use indicatif::{ProgressBar, ProgressStyle};
1011
use log::{debug, error, info, trace, warn};
1112
use rayon::prelude::*;
1213
use rust_lapper::{Interval, Lapper};
13-
use std::collections::HashMap;
14+
use std::collections::{HashMap, HashSet};
1415
use std::fs::File;
1516
use std::io::{BufWriter, Write};
1617
use std::path::Path;
@@ -49,6 +50,11 @@ pub struct Args {
4950
/// Number of threads for parallel processing.
5051
#[arg(short = 't', long = "threads")]
5152
pub num_cpu: Option<usize>,
53+
54+
/// [NEW] Enable nested lineage classification logic.
55+
/// This requires markers to be defined with multiple columns for each level.
56+
#[arg(long)]
57+
pub nested_classification: bool,
5258
}
5359

5460
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -126,6 +132,7 @@ fn find_markers(
126132
matched_markers
127133
}
128134

135+
// MODIFIED: This function now reads multiple lineage columns.
129136
fn get_positions(
130137
tsv_file: &str,
131138
) -> AppResult<(HashMap<usize, String>, HashMap<usize, String>)> {
@@ -134,20 +141,26 @@ fn get_positions(
134141
let mut markers_lineage = HashMap::new();
135142
for result in rdr.records() {
136143
let record = result?;
137-
if record.len() < 3 {
144+
if record.len() < 4 { // pos, ref, alt, level1
138145
continue;
139146
}
140147
let pos: usize = record[0].parse().map_err(|e| AppError::Parsing(format!("Invalid position in TSV: {}", e)))?;
141-
let alt_base = record[1].to_string();
142-
let lineage = record[2].to_string();
148+
let alt_base = record[2].to_string(); // ALT is the 3rd column
149+
150+
// Read all lineage columns and join them with a semicolon.
151+
let lineage_path = record.iter()
152+
.skip(3) // Start from the first lineage column
153+
.take_while(|&s| !s.trim().is_empty()) // Take until an empty cell is found
154+
.collect::<Vec<&str>>()
155+
.join(";");
156+
143157
reference_positions.insert(pos, alt_base);
144-
markers_lineage.insert(pos, lineage);
158+
markers_lineage.insert(pos, lineage_path);
145159
}
146160
Ok((reference_positions, markers_lineage))
147161
}
148162

149163
fn get_ref(fasta_file: &str) -> AppResult<String> {
150-
// FIX: Manually map the error from anyhow::Error to our AppError
151164
let reader = fasta::Reader::from_file(fasta_file)
152165
.map_err(|e| AppError::Generic(format!("Failed to open FASTA file {}: {}", fasta_file, e)))?;
153166
let mut records = reader.records();
@@ -224,7 +237,6 @@ fn get_genomepaths(
224237
}
225238

226239
fn get_genomes_from_fasta(fasta_file: &str) -> AppResult<HashMap<String, String>> {
227-
// FIX: Manually map the error from anyhow::Error to our AppError
228240
let reader = fasta::Reader::from_file(fasta_file)
229241
.map_err(|e| AppError::Generic(format!("Failed to open FASTA file {}: {}", fasta_file, e)))?;
230242
let mut genomes = HashMap::new();
@@ -276,7 +288,6 @@ fn analyze_genome(
276288
for (kmer, (position, ref_position, lineage)) in matched_markers {
277289
let snp_position = position + k / 2;
278290
let (gene_id, aa_pos, aa_change) = if let Some(tree) = &annotations {
279-
// Find gene overlapping with the SNP
280291
if let Some(gene_interval) = tree.find(snp_position, snp_position + 1).next() {
281292
let alt_base = &genome_seq[snp_position..snp_position + 1];
282293
translate_snp_info(gene_interval, snp_position, alt_base, &genome_seq)
@@ -381,7 +392,6 @@ fn parse_gff_and_build_tree(gff_file: &str) -> AppResult<Lapper<usize, Gene>> {
381392
}
382393
}
383394
}
384-
// FIX: Get the length *before* moving the value
385395
let num_intervals = intervals.len();
386396
let lapper = Lapper::new(intervals);
387397
debug!("Successfully parsed {} CDS features from GFF.", num_intervals);
@@ -549,51 +559,140 @@ fn process_genomes(
549559
Ok(results)
550560
}
551561

562+
/// [CORRECTED LOGIC] This function now counts all levels of the hierarchy for each marker.
552563
fn generate_summary(results: &[String]) -> HashMap<String, HashMap<String, usize>> {
553-
let mut lineage_count_map = HashMap::new();
564+
let mut lineage_count_map: HashMap<String, HashMap<String, usize>> = HashMap::new();
554565

555566
for line in results {
556567
let fields: Vec<&str> = line.trim_end().split('\t').collect();
557568
if fields.len() < 6 || fields[5].is_empty() {
558569
continue;
559570
}
560571
let genome_name = fields[0].to_string();
561-
let lineage = fields[5].to_string();
562-
563-
lineage_count_map
564-
.entry(genome_name)
565-
.or_insert_with(HashMap::new)
566-
.entry(lineage)
567-
.and_modify(|c| *c += 1)
568-
.or_insert(1);
572+
let lineage_path = fields[5].to_string();
573+
574+
// Add count for each level in this marker's hierarchy
575+
let mut current_path_part = String::new();
576+
for (i, component) in lineage_path.split(';').enumerate() {
577+
current_path_part = if i == 0 {
578+
component.to_string()
579+
} else {
580+
format!("{};{}", current_path_part, component)
581+
};
582+
*lineage_count_map
583+
.entry(genome_name.clone())
584+
.or_default()
585+
.entry(current_path_part.clone())
586+
.or_default() += 1;
587+
}
569588
}
570589

571590
lineage_count_map
572591
}
573592

593+
/// [REWRITTEN] Final logic for nested classification.
594+
/// 1. Finds the deepest valid hierarchical path.
595+
/// 2. Finds the most abundant lineage overall.
596+
/// 3. If the most abundant is from a different branch and has more support, it wins.
597+
/// 4. Otherwise, the deepest path wins.
598+
fn get_final_lineage_call(lineage_counts: &HashMap<String, usize>) -> String {
599+
if lineage_counts.is_empty() {
600+
return "Unclassified".to_string();
601+
}
602+
603+
let supported_lineages: HashSet<String> = lineage_counts.keys().cloned().collect();
604+
let mut valid_candidates = Vec::new();
605+
606+
// 1. Identify all lineages with a valid, fully supported path.
607+
for candidate in supported_lineages.iter() {
608+
let mut is_path_valid = true;
609+
if candidate.contains(';') {
610+
let components: Vec<&str> = candidate.split(';').collect();
611+
for i in 1..components.len() {
612+
let parent_path = components[0..i].join(";");
613+
if !supported_lineages.contains(&parent_path) {
614+
is_path_valid = false;
615+
break;
616+
}
617+
}
618+
}
619+
if is_path_valid {
620+
valid_candidates.push(candidate.clone());
621+
}
622+
}
623+
624+
// If no valid paths exist, fallback to the most abundant lineage.
625+
if valid_candidates.is_empty() {
626+
return lineage_counts
627+
.iter()
628+
.max_by_key(|&(_, count)| count)
629+
.map(|(lineage, _)| lineage.clone())
630+
.unwrap_or_else(|| "Unclassified".to_string());
631+
}
632+
633+
// 2. From the valid candidates, find the deepest one.
634+
// If there's a tie in depth, the higher SNP count for that specific deep lineage wins.
635+
let best_deep_lineage = valid_candidates
636+
.iter()
637+
.max_by(|a, b| {
638+
let depth_a = a.split(';').count();
639+
let depth_b = b.split(';').count();
640+
depth_a.cmp(&depth_b)
641+
.then_with(|| lineage_counts[*a].cmp(&lineage_counts[*b]))
642+
})
643+
.unwrap(); // Safe because valid_candidates is not empty.
644+
645+
// 3. Find the most abundant lineage overall.
646+
let most_abundant_lineage = lineage_counts
647+
.iter()
648+
.max_by_key(|&(_, count)| count)
649+
.map(|(lineage, _)| lineage)
650+
.unwrap(); // Safe because lineage_counts is not empty.
651+
652+
// 4. The final decision logic.
653+
let best_deep_count = lineage_counts[best_deep_lineage];
654+
let most_abundant_count = lineage_counts[most_abundant_lineage];
655+
656+
// Check if the most abundant lineage is from a different branch than the deepest one.
657+
// A shared branch means one starts with the other.
658+
let shares_branch = best_deep_lineage.starts_with(most_abundant_lineage) ||
659+
most_abundant_lineage.starts_with(best_deep_lineage);
660+
661+
if !shares_branch && most_abundant_count > best_deep_count {
662+
// The most abundant lineage is from a different branch and has more support, so it wins.
663+
most_abundant_lineage.clone()
664+
} else {
665+
// Otherwise, prioritize the deepest valid path.
666+
best_deep_lineage.clone()
667+
}
668+
}
669+
670+
574671
fn write_summary(
575672
summary_out: &mut BufWriter<File>,
576673
genome: String,
577-
lineage_counts: Vec<(String, usize)>,
674+
lineage_counts: &HashMap<String, usize>,
675+
nested_classification: bool,
578676
) -> AppResult<()> {
579-
let lineage_count_str = lineage_counts
677+
678+
let mut sorted_lineages: Vec<(&String, &usize)> = lineage_counts.iter().collect();
679+
sorted_lineages.sort_by(|a, b| b.1.cmp(a.1).then_with(|| a.0.cmp(b.0)));
680+
681+
let lineage_count_str = sorted_lineages
580682
.iter()
581683
.map(|(lin, cnt)| format!("{}:{}", lin, cnt))
582684
.collect::<Vec<_>>()
583-
.join(",");
584-
585-
let majority_lineage = if lineage_counts.is_empty() {
586-
String::new()
587-
} else if lineage_counts.len() == 1 {
588-
lineage_counts[0].0.clone()
589-
} else if lineage_counts[0].1 > lineage_counts[1].1 {
590-
lineage_counts[0].0.clone()
685+
.join(" ");
686+
687+
let majority_lineage = if nested_classification {
688+
get_final_lineage_call(lineage_counts)
591689
} else {
592-
lineage_counts
593-
.iter()
594-
.map(|(lin, _)| lin.clone())
595-
.collect::<Vec<_>>()
596-
.join(",")
690+
// Original logic
691+
if sorted_lineages.is_empty() {
692+
"Unclassified".to_string()
693+
} else {
694+
sorted_lineages[0].0.clone()
695+
}
597696
};
598697

599698
writeln!(
@@ -648,9 +747,7 @@ pub fn run(args: Args) -> AppResult<()> {
648747
let lineage_count_map = generate_summary(&results);
649748

650749
for (genome, lineage_map) in lineage_count_map {
651-
let mut lineage_counts: Vec<(String, usize)> = lineage_map.into_iter().collect();
652-
lineage_counts.sort_by(|a, b| b.1.cmp(&a.1));
653-
write_summary(&mut summary_out, genome, lineage_counts)?;
750+
write_summary(&mut summary_out, genome, &lineage_map, args.nested_classification)?;
654751
}
655752
info!("Classification complete.");
656753
Ok(())

0 commit comments

Comments
 (0)