11//! This module handles the `classify` subcommand.
2+ //! MODIFIED: Includes strictly corrected logic for nested lineage classification.
23
34use crate :: errors:: { AppError , AppResult } ;
45use bio:: io:: fasta;
@@ -10,7 +11,7 @@ use indicatif::{ProgressBar, ProgressStyle};
1011use log:: { debug, error, info, trace, warn} ;
1112use rayon:: prelude:: * ;
1213use rust_lapper:: { Interval , Lapper } ;
13- use std:: collections:: HashMap ;
14+ use std:: collections:: { HashMap , HashSet } ;
1415use std:: fs:: File ;
1516use std:: io:: { BufWriter , Write } ;
1617use std:: path:: Path ;
@@ -49,6 +50,11 @@ pub struct Args {
4950 /// Number of threads for parallel processing.
5051 #[ arg( short = 't' , long = "threads" ) ]
5152 pub num_cpu : Option < usize > ,
53+
54+ /// [NEW] Enable nested lineage classification logic.
55+ /// This requires markers to be defined with multiple columns for each level.
56+ #[ arg( long) ]
57+ pub nested_classification : bool ,
5258}
5359
5460#[ derive( Clone , Debug , PartialEq , Eq ) ]
@@ -126,6 +132,7 @@ fn find_markers(
126132 matched_markers
127133}
128134
135+ // MODIFIED: This function now reads multiple lineage columns.
129136fn get_positions (
130137 tsv_file : & str ,
131138) -> AppResult < ( HashMap < usize , String > , HashMap < usize , String > ) > {
@@ -134,20 +141,26 @@ fn get_positions(
134141 let mut markers_lineage = HashMap :: new ( ) ;
135142 for result in rdr. records ( ) {
136143 let record = result?;
137- if record. len ( ) < 3 {
144+ if record. len ( ) < 4 { // pos, ref, alt, level1
138145 continue ;
139146 }
140147 let pos: usize = record[ 0 ] . parse ( ) . map_err ( |e| AppError :: Parsing ( format ! ( "Invalid position in TSV: {}" , e) ) ) ?;
141- let alt_base = record[ 1 ] . to_string ( ) ;
142- let lineage = record[ 2 ] . to_string ( ) ;
148+ let alt_base = record[ 2 ] . to_string ( ) ; // ALT is the 3rd column
149+
150+ // Read all lineage columns and join them with a semicolon.
151+ let lineage_path = record. iter ( )
152+ . skip ( 3 ) // Start from the first lineage column
153+ . take_while ( |& s| !s. trim ( ) . is_empty ( ) ) // Take until an empty cell is found
154+ . collect :: < Vec < & str > > ( )
155+ . join ( ";" ) ;
156+
143157 reference_positions. insert ( pos, alt_base) ;
144- markers_lineage. insert ( pos, lineage ) ;
158+ markers_lineage. insert ( pos, lineage_path ) ;
145159 }
146160 Ok ( ( reference_positions, markers_lineage) )
147161}
148162
149163fn get_ref ( fasta_file : & str ) -> AppResult < String > {
150- // FIX: Manually map the error from anyhow::Error to our AppError
151164 let reader = fasta:: Reader :: from_file ( fasta_file)
152165 . map_err ( |e| AppError :: Generic ( format ! ( "Failed to open FASTA file {}: {}" , fasta_file, e) ) ) ?;
153166 let mut records = reader. records ( ) ;
@@ -224,7 +237,6 @@ fn get_genomepaths(
224237}
225238
226239fn get_genomes_from_fasta ( fasta_file : & str ) -> AppResult < HashMap < String , String > > {
227- // FIX: Manually map the error from anyhow::Error to our AppError
228240 let reader = fasta:: Reader :: from_file ( fasta_file)
229241 . map_err ( |e| AppError :: Generic ( format ! ( "Failed to open FASTA file {}: {}" , fasta_file, e) ) ) ?;
230242 let mut genomes = HashMap :: new ( ) ;
@@ -276,7 +288,6 @@ fn analyze_genome(
276288 for ( kmer, ( position, ref_position, lineage) ) in matched_markers {
277289 let snp_position = position + k / 2 ;
278290 let ( gene_id, aa_pos, aa_change) = if let Some ( tree) = & annotations {
279- // Find gene overlapping with the SNP
280291 if let Some ( gene_interval) = tree. find ( snp_position, snp_position + 1 ) . next ( ) {
281292 let alt_base = & genome_seq[ snp_position..snp_position + 1 ] ;
282293 translate_snp_info ( gene_interval, snp_position, alt_base, & genome_seq)
@@ -381,7 +392,6 @@ fn parse_gff_and_build_tree(gff_file: &str) -> AppResult<Lapper<usize, Gene>> {
381392 }
382393 }
383394 }
384- // FIX: Get the length *before* moving the value
385395 let num_intervals = intervals. len ( ) ;
386396 let lapper = Lapper :: new ( intervals) ;
387397 debug ! ( "Successfully parsed {} CDS features from GFF." , num_intervals) ;
@@ -549,51 +559,140 @@ fn process_genomes(
549559 Ok ( results)
550560}
551561
/// Tallies marker hits per genome at every level of each marker's lineage hierarchy.
///
/// Each entry of `results` is a tab-separated line whose first column is the genome name and
/// whose sixth column is a `;`-joined lineage path. Lines with fewer than six columns (after
/// trailing whitespace is trimmed) or with an empty sixth column are skipped.
///
/// A marker whose path is `A;B;C` adds one to `A`, `A;B` and `A;B;C` for its genome, so a
/// parent's count always includes the markers of its descendants.
///
/// Returns a map of genome name -> (lineage path prefix -> marker count).
fn generate_summary(results: &[String]) -> HashMap<String, HashMap<String, usize>> {
    let mut lineage_count_map: HashMap<String, HashMap<String, usize>> = HashMap::new();

    for line in results {
        let mut columns = line.trim_end().split('\t');
        let genome_name = columns.next();
        // `nth(4)` skips columns 1-4 and yields column 5, the lineage path.
        let lineage_path = columns.nth(4);
        let (genome_name, lineage_path) = match (genome_name, lineage_path) {
            (Some(genome), Some(path)) if !path.is_empty() => (genome, path),
            _ => continue,
        };

        // Resolve the per-genome map once per line instead of once per hierarchy level.
        let genome_counts = lineage_count_map
            .entry(genome_name.to_string())
            .or_default();

        // Every `;` ends one ancestor prefix; the full path is the final (deepest) level.
        let prefix_ends = lineage_path
            .match_indices(';')
            .map(|(idx, _)| idx)
            .chain(std::iter::once(lineage_path.len()));

        for end in prefix_ends {
            *genome_counts
                .entry(lineage_path[..end].to_string())
                .or_default() += 1;
        }
    }

    lineage_count_map
}
573592
/// Chooses the final lineage call for one genome from its per-lineage marker counts.
///
/// Keys of `lineage_counts` are `;`-joined hierarchical paths (for example `L1;L1.2`) and
/// values are the number of markers supporting that path.
///
/// Decision procedure:
/// 1. A path is a valid candidate when every one of its ancestor prefixes is also a key.
/// 2. Among valid candidates the deepest path is preferred; equal depth is resolved by the
///    higher count.
/// 3. The most abundant lineage overall is determined.
/// 4. If the most abundant lineage lies on a different branch than the deepest candidate and
///    has strictly more support, it wins; otherwise the deepest candidate wins.
///
/// Two paths share a branch only when one equals the other or is its ancestor on a `;`
/// component boundary, so `L1` and `L10;x` are different branches even though one string is
/// a textual prefix of the other.
///
/// All remaining ties are broken in favour of the lexicographically smaller path, which makes
/// the result independent of `HashMap` iteration order.
///
/// Returns `"Unclassified"` when `lineage_counts` is empty. When no path is a valid candidate
/// the most abundant lineage is returned.
fn get_final_lineage_call(lineage_counts: &HashMap<String, usize>) -> String {
    /// Number of `;`-separated levels in `path`.
    fn depth(path: &str) -> usize {
        path.split(';').count()
    }

    /// True when `ancestor` equals `path` or is a prefix of it that ends on a `;` boundary.
    fn is_same_or_ancestor(ancestor: &str, path: &str) -> bool {
        path.strip_prefix(ancestor)
            .map_or(false, |rest| rest.is_empty() || rest.starts_with(';'))
    }

    // Most abundant lineage overall; the reversed name comparison makes the smaller name win
    // a tie on count.
    let most_abundant = lineage_counts
        .iter()
        .max_by(|(path_a, count_a), (path_b, count_b)| {
            count_a.cmp(count_b).then_with(|| path_b.cmp(path_a))
        });
    let (most_abundant_lineage, most_abundant_count) = match most_abundant {
        Some((path, &count)) => (path, count),
        None => return "Unclassified".to_string(),
    };

    let supported_lineages: HashSet<&str> = lineage_counts.keys().map(String::as_str).collect();

    // Deepest path whose every ancestor prefix is itself supported. Each `;` position marks
    // the end of one ancestor prefix, so no intermediate strings need to be built.
    let deepest_valid = lineage_counts
        .iter()
        .filter(|(path, _)| {
            path.match_indices(';')
                .all(|(idx, _)| supported_lineages.contains(&path[..idx]))
        })
        .max_by(|(path_a, count_a), (path_b, count_b)| {
            depth(path_a)
                .cmp(&depth(path_b))
                .then_with(|| count_a.cmp(count_b))
                .then_with(|| path_b.cmp(path_a))
        });
    let (best_deep_lineage, best_deep_count) = match deepest_valid {
        Some((path, &count)) => (path, count),
        // No fully supported path exists: fall back to the most abundant lineage.
        None => return most_abundant_lineage.clone(),
    };

    let shares_branch = is_same_or_ancestor(most_abundant_lineage, best_deep_lineage)
        || is_same_or_ancestor(best_deep_lineage, most_abundant_lineage);

    if !shares_branch && most_abundant_count > best_deep_count {
        // A better-supported lineage from another branch overrides the deepest path.
        most_abundant_lineage.clone()
    } else {
        best_deep_lineage.clone()
    }
}
669+
670+
574671fn write_summary (
575672 summary_out : & mut BufWriter < File > ,
576673 genome : String ,
577- lineage_counts : Vec < ( String , usize ) > ,
674+ lineage_counts : & HashMap < String , usize > ,
675+ nested_classification : bool ,
578676) -> AppResult < ( ) > {
579- let lineage_count_str = lineage_counts
677+
678+ let mut sorted_lineages: Vec < ( & String , & usize ) > = lineage_counts. iter ( ) . collect ( ) ;
679+ sorted_lineages. sort_by ( |a, b| b. 1 . cmp ( a. 1 ) . then_with ( || a. 0 . cmp ( b. 0 ) ) ) ;
680+
681+ let lineage_count_str = sorted_lineages
580682 . iter ( )
581683 . map ( |( lin, cnt) | format ! ( "{}:{}" , lin, cnt) )
582684 . collect :: < Vec < _ > > ( )
583- . join ( "," ) ;
584-
585- let majority_lineage = if lineage_counts. is_empty ( ) {
586- String :: new ( )
587- } else if lineage_counts. len ( ) == 1 {
588- lineage_counts[ 0 ] . 0 . clone ( )
589- } else if lineage_counts[ 0 ] . 1 > lineage_counts[ 1 ] . 1 {
590- lineage_counts[ 0 ] . 0 . clone ( )
685+ . join ( " " ) ;
686+
687+ let majority_lineage = if nested_classification {
688+ get_final_lineage_call ( lineage_counts)
591689 } else {
592- lineage_counts
593- . iter ( )
594- . map ( |( lin, _) | lin. clone ( ) )
595- . collect :: < Vec < _ > > ( )
596- . join ( "," )
690+ // Original logic
691+ if sorted_lineages. is_empty ( ) {
692+ "Unclassified" . to_string ( )
693+ } else {
694+ sorted_lineages[ 0 ] . 0 . clone ( )
695+ }
597696 } ;
598697
599698 writeln ! (
@@ -648,9 +747,7 @@ pub fn run(args: Args) -> AppResult<()> {
648747 let lineage_count_map = generate_summary ( & results) ;
649748
650749 for ( genome, lineage_map) in lineage_count_map {
651- let mut lineage_counts: Vec < ( String , usize ) > = lineage_map. into_iter ( ) . collect ( ) ;
652- lineage_counts. sort_by ( |a, b| b. 1 . cmp ( & a. 1 ) ) ;
653- write_summary ( & mut summary_out, genome, lineage_counts) ?;
750+ write_summary ( & mut summary_out, genome, & lineage_map, args. nested_classification ) ?;
654751 }
655752 info ! ( "Classification complete." ) ;
656753 Ok ( ( ) )
0 commit comments