11//! src/common.rs
22//!
3- //! Module for shared data structures and utility functions across the application.
4- //!
5- //! This module centralizes components like the model bundle, vectorizer,
6- //! and label encoder to avoid code duplication between the `train` and `predict`
7- //! modules.
3+ //! This module provides shared, high-performance data structures and utility
4+ //! functions for the `pathotypr` application. It has been optimized to reduce
5+ //! memory allocations and improve processing speed, especially for the `train`
6+ //! and `predict` subcommands.
87
8+ // --- Crates for performance and serialization ---
99use rayon:: prelude:: * ;
10+ use rustc_hash:: FxHashMap ; // A much faster hasher for integer keys
1011use serde:: { Deserialize , Serialize } ;
11- // We only need the Decision Tree, not the whole Random Forest
12- use smartcore:: tree:: decision_tree_classifier:: DecisionTreeClassifier ;
12+ use needletail:: Sequence ; // Import the Sequence trait to use its methods
13+
14+ // --- SmartCore components for machine learning ---
15+ // MODIFIED: All types now use f32 for memory efficiency.
1316use smartcore:: linalg:: basic:: matrix:: DenseMatrix ;
14- use std :: collections :: HashMap ;
17+ use smartcore :: tree :: decision_tree_classifier :: DecisionTreeClassifier ;
1518
1619// --- Model Bundle Structs ---
1720
18- /// Configuration for the model, such as the k-mer size .
21+ /// Defines the configuration of a trained model .
1922#[ derive( Serialize , Deserialize , Debug ) ]
2023pub struct ModelConfig {
24+ pub pathotypr_version : String ,
2125 pub kmer_size : usize ,
22- // Add the number of trees to the config
2326 pub n_trees : u16 ,
2427}
2528
26- /// A unified bundle containing everything needed for prediction.
29+ /// A unified, compressed bundle containing everything needed for prediction.
2730#[ derive( Serialize , Deserialize , Debug ) ]
2831pub struct ModelBundle {
2932 pub config : ModelConfig ,
3033 pub vectorizer : CountVectorizer ,
3134 pub label_encoder : LabelEncoder ,
32- // CHANGE: We now store a vector of individual decision trees
33- pub trees : Vec < DecisionTreeClassifier < f64 , usize , DenseMatrix < f64 > , Vec < usize > > > ,
35+ // MODIFIED: The DecisionTreeClassifier now uses f32.
36+ pub trees : Vec < DecisionTreeClassifier < f32 , usize , DenseMatrix < f32 > , Vec < usize > > > ,
3437}
3538
3639// --- Feature Processing Components ---
3740
38- /// Transforms text into k-mer count vectors.
41+ /// Transforms sequences into k-mer count vectors using `u64` representation .
3942#[ derive( Serialize , Deserialize , Debug ) ]
4043pub struct CountVectorizer {
41- pub vocabulary : HashMap < String , usize > ,
42- pub feature_names : Vec < String > ,
44+ pub vocabulary : FxHashMap < u64 , usize > ,
45+ pub num_features : usize ,
4346}
4447
4548impl CountVectorizer {
49+ /// Creates a new, empty `CountVectorizer`.
4650 pub fn new ( ) -> Self {
4751 Self {
48- vocabulary : HashMap :: new ( ) ,
49- feature_names : Vec :: new ( ) ,
52+ vocabulary : FxHashMap :: default ( ) ,
53+ num_features : 0 ,
5054 }
5155 }
5256
53- /// Builds the vocabulary from a collection of texts.
54- pub fn fit < T : AsRef < str > > ( & mut self , texts : & [ T ] ) {
55- let mut freq: HashMap < String , usize > = HashMap :: new ( ) ;
56- for text in texts {
57- for token in text. as_ref ( ) . split_whitespace ( ) {
58- * freq. entry ( token. to_string ( ) ) . or_insert ( 0 ) += 1 ;
59- }
57+ /// Builds the vocabulary from a pre-computed map of k-mer counts.
58+ pub fn fit ( & mut self , kmer_counts : & FxHashMap < u64 , u32 > ) {
59+ let mut vocab_idx = 0 ;
60+ for & kmer_hash in kmer_counts. keys ( ) {
61+ self . vocabulary . insert ( kmer_hash, vocab_idx) ;
62+ vocab_idx += 1 ;
6063 }
61- let mut freq_vec: Vec < ( String , usize ) > = freq. into_iter ( ) . collect ( ) ;
62- freq_vec. sort_by ( |a, b| b. 1 . cmp ( & a. 1 ) ) ;
63- self . vocabulary = freq_vec
64- . iter ( )
65- . enumerate ( )
66- . map ( |( i, ( token, _) ) | ( token. clone ( ) , i) )
67- . collect ( ) ;
68- self . feature_names = freq_vec. into_iter ( ) . map ( |( token, _) | token) . collect ( ) ;
64+ self . num_features = self . vocabulary . len ( ) ;
6965 }
7066
71- /// Transforms a collection of texts into a feature matrix.
72- pub fn transform < T : AsRef < str > + Sync > ( & self , texts : & [ T ] ) -> Vec < Vec < f64 > > {
73- texts
67+ /// MODIFIED: Transforms sequences into a sparse data format using f32.
68+ /// Instead of a giant Vec<Vec<f64>>, this returns a Vec of sparse vectors.
69+ /// Each sparse vector is a Vec of (feature_index, count).
70+ pub fn transform_sparse ( & self , sequences : & [ String ] , k : usize ) -> Vec < Vec < ( usize , f32 ) > > {
71+ sequences
7472 . par_iter ( )
75- . map ( |text| {
76- let mut counts = vec ! [ 0.0 ; self . vocabulary. len( ) ] ;
77- for token in text. as_ref ( ) . split_whitespace ( ) {
78- if let Some ( & idx) = self . vocabulary . get ( token) {
79- counts[ idx] += 1.0 ;
73+ . map ( |seq| {
74+ // Use a temporary HashMap to count k-mers for this sequence only.
75+ // This is memory-efficient as it's local to the sequence.
76+ let mut sequence_kmer_counts: FxHashMap < usize , f32 > = FxHashMap :: default ( ) ;
77+ for ( _, bitkmer_tuple, _) in seq. as_bytes ( ) . bit_kmers ( k as u8 , true ) {
78+ let kmer_hash = bitkmer_tuple. 0 ;
79+ if let Some ( & idx) = self . vocabulary . get ( & kmer_hash) {
80+ * sequence_kmer_counts. entry ( idx) . or_insert ( 0.0 ) += 1.0 ;
8081 }
8182 }
82- counts
83+ // Convert the map to a vector of (index, value) tuples.
84+ // Sorting is good practice for some sparse matrix formats.
85+ let mut features: Vec < ( usize , f32 ) > = sequence_kmer_counts. into_iter ( ) . collect ( ) ;
86+ features. sort_unstable_by_key ( |& ( idx, _) | idx) ;
87+ features
8388 } )
8489 . collect ( )
8590 }
8691}
8792
88- /// Encodes class labels (strings) into integers .
93+ /// Encodes string labels into integer representations and vice-versa .
8994#[ derive( Serialize , Deserialize , Debug ) ]
9095pub struct LabelEncoder {
91- pub label_to_int : HashMap < String , usize > ,
96+ pub label_to_int : FxHashMap < String , usize > ,
9297 pub int_to_label : Vec < String > ,
9398}
9499
95100impl LabelEncoder {
101+ /// Creates a new, empty `LabelEncoder`.
96102 pub fn new ( ) -> Self {
97103 Self {
98- label_to_int : HashMap :: new ( ) ,
104+ label_to_int : FxHashMap :: default ( ) ,
99105 int_to_label : Vec :: new ( ) ,
100106 }
101107 }
102108
103- /// Learns the label mapping from a collection of labels.
104- pub fn fit < T : AsRef < str > > ( & mut self , labels : & [ T ] ) {
109+ /// Learns the mapping from a slice of string labels.
110+ pub fn fit < T : AsRef < str > + std :: hash :: Hash + std :: cmp :: Eq > ( & mut self , labels : & [ T ] ) {
105111 for label in labels {
106- let label_str = label. as_ref ( ) ;
107- if ! self . label_to_int . contains_key ( label_str) {
112+ let label_str = label. as_ref ( ) . to_string ( ) ;
113+ self . label_to_int . entry ( label_str. clone ( ) ) . or_insert_with ( || {
108114 let index = self . int_to_label . len ( ) ;
109- self . label_to_int . insert ( label_str. to_string ( ) , index ) ;
110- self . int_to_label . push ( label_str . to_string ( ) ) ;
111- }
115+ self . int_to_label . push ( label_str) ;
116+ index
117+ } ) ;
112118 }
113119 }
114120
115- /// Transforms a collection of labels into their integer representations.
121+ /// Transforms a slice of string labels into their integer representations.
116122 pub fn transform < T : AsRef < str > > ( & self , labels : & [ T ] ) -> Vec < usize > {
117123 labels
118124 . iter ( )
119- . map ( |label| * self . label_to_int . get ( label. as_ref ( ) ) . unwrap ( ) )
125+ . map ( |label| {
126+ * self . label_to_int . get ( label. as_ref ( ) )
127+ . expect ( "Label not found in encoder. `fit` must be called first with all possible labels." )
128+ } )
120129 . collect ( )
121130 }
122- }
123-
124- // --- Utility Functions ---
125-
126- /// Generates a space-separated string of k-mers from a DNA sequence.
127- ///
128- /// # Arguments
129- /// * `sequence` - The input DNA sequence.
130- /// * `k` - The k-mer size.
131- ///
132- /// # Returns
133- /// A `String` of k-mers. Returns an empty string if the sequence is shorter than `k`.
134- pub fn kmerize ( sequence : & str , k : usize ) -> String {
135- if sequence. len ( ) < k {
136- return String :: new ( ) ;
137- }
138- ( 0 ..=sequence. len ( ) - k)
139- . map ( |i| & sequence[ i..i + k] )
140- . collect :: < Vec < & str > > ( )
141- . join ( " " )
142- }
131+ }
0 commit comments