Skip to content

Commit a4ebae4

Browse files
authored
improve kmers
1 parent 7a81cc9 commit a4ebae4

3 files changed

Lines changed: 350 additions & 251 deletions

File tree

src/common.rs

Lines changed: 64 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,142 +1,131 @@
11
//! src/common.rs
22
//!
3-
//! Module for shared data structures and utility functions across the application.
4-
//!
5-
//! This module centralizes components like the model bundle, vectorizer,
6-
//! and label encoder to avoid code duplication between the `train` and `predict`
7-
//! modules.
3+
//! This module provides shared, high-performance data structures and utility
4+
//! functions for the `pathotypr` application. It has been optimized to reduce
5+
//! memory allocations and improve processing speed, especially for the `train`
6+
//! and `predict` subcommands.
87
8+
// --- Crates for performance and serialization ---
99
use rayon::prelude::*;
10+
use rustc_hash::FxHashMap; // A much faster hasher for integer keys
1011
use serde::{Deserialize, Serialize};
11-
// We only need the Decision Tree, not the whole Random Forest
12-
use smartcore::tree::decision_tree_classifier::DecisionTreeClassifier;
12+
use needletail::Sequence; // Import the Sequence trait to use its methods
13+
14+
// --- SmartCore components for machine learning ---
15+
// MODIFIED: All types now use f32 for memory efficiency.
1316
use smartcore::linalg::basic::matrix::DenseMatrix;
14-
use std::collections::HashMap;
17+
use smartcore::tree::decision_tree_classifier::DecisionTreeClassifier;
1518

1619
// --- Model Bundle Structs ---
1720

18-
/// Configuration for the model, such as the k-mer size.
21+
/// Defines the configuration of a trained model.
1922
#[derive(Serialize, Deserialize, Debug)]
2023
pub struct ModelConfig {
24+
pub pathotypr_version: String,
2125
pub kmer_size: usize,
22-
// Add the number of trees to the config
2326
pub n_trees: u16,
2427
}
2528

26-
/// A unified bundle containing everything needed for prediction.
29+
/// A unified, compressed bundle containing everything needed for prediction.
2730
#[derive(Serialize, Deserialize, Debug)]
2831
pub struct ModelBundle {
2932
pub config: ModelConfig,
3033
pub vectorizer: CountVectorizer,
3134
pub label_encoder: LabelEncoder,
32-
// CHANGE: We now store a vector of individual decision trees
33-
pub trees: Vec<DecisionTreeClassifier<f64, usize, DenseMatrix<f64>, Vec<usize>>>,
35+
// MODIFIED: The DecisionTreeClassifier now uses f32.
36+
pub trees: Vec<DecisionTreeClassifier<f32, usize, DenseMatrix<f32>, Vec<usize>>>,
3437
}
3538

3639
// --- Feature Processing Components ---
3740

38-
/// Transforms text into k-mer count vectors.
41+
/// Transforms sequences into k-mer count vectors using `u64` representation.
3942
#[derive(Serialize, Deserialize, Debug)]
4043
pub struct CountVectorizer {
41-
pub vocabulary: HashMap<String, usize>,
42-
pub feature_names: Vec<String>,
44+
pub vocabulary: FxHashMap<u64, usize>,
45+
pub num_features: usize,
4346
}
4447

4548
impl CountVectorizer {
49+
/// Creates a new, empty `CountVectorizer`.
4650
pub fn new() -> Self {
4751
Self {
48-
vocabulary: HashMap::new(),
49-
feature_names: Vec::new(),
52+
vocabulary: FxHashMap::default(),
53+
num_features: 0,
5054
}
5155
}
5256

53-
/// Builds the vocabulary from a collection of texts.
54-
pub fn fit<T: AsRef<str>>(&mut self, texts: &[T]) {
55-
let mut freq: HashMap<String, usize> = HashMap::new();
56-
for text in texts {
57-
for token in text.as_ref().split_whitespace() {
58-
*freq.entry(token.to_string()).or_insert(0) += 1;
59-
}
57+
/// Builds the vocabulary from a pre-computed map of k-mer counts.
58+
pub fn fit(&mut self, kmer_counts: &FxHashMap<u64, u32>) {
59+
let mut vocab_idx = 0;
60+
for &kmer_hash in kmer_counts.keys() {
61+
self.vocabulary.insert(kmer_hash, vocab_idx);
62+
vocab_idx += 1;
6063
}
61-
let mut freq_vec: Vec<(String, usize)> = freq.into_iter().collect();
62-
freq_vec.sort_by(|a, b| b.1.cmp(&a.1));
63-
self.vocabulary = freq_vec
64-
.iter()
65-
.enumerate()
66-
.map(|(i, (token, _))| (token.clone(), i))
67-
.collect();
68-
self.feature_names = freq_vec.into_iter().map(|(token, _)| token).collect();
64+
self.num_features = self.vocabulary.len();
6965
}
7066

71-
/// Transforms a collection of texts into a feature matrix.
72-
pub fn transform<T: AsRef<str> + Sync>(&self, texts: &[T]) -> Vec<Vec<f64>> {
73-
texts
67+
/// MODIFIED: Transforms sequences into a sparse data format using f32.
68+
/// Instead of a giant Vec<Vec<f64>>, this returns a Vec of sparse vectors.
69+
/// Each sparse vector is a Vec of (feature_index, count).
70+
pub fn transform_sparse(&self, sequences: &[String], k: usize) -> Vec<Vec<(usize, f32)>> {
71+
sequences
7472
.par_iter()
75-
.map(|text| {
76-
let mut counts = vec![0.0; self.vocabulary.len()];
77-
for token in text.as_ref().split_whitespace() {
78-
if let Some(&idx) = self.vocabulary.get(token) {
79-
counts[idx] += 1.0;
73+
.map(|seq| {
74+
// Use a temporary HashMap to count k-mers for this sequence only.
75+
// This is memory-efficient as it's local to the sequence.
76+
let mut sequence_kmer_counts: FxHashMap<usize, f32> = FxHashMap::default();
77+
for (_, bitkmer_tuple, _) in seq.as_bytes().bit_kmers(k as u8, true) {
78+
let kmer_hash = bitkmer_tuple.0;
79+
if let Some(&idx) = self.vocabulary.get(&kmer_hash) {
80+
*sequence_kmer_counts.entry(idx).or_insert(0.0) += 1.0;
8081
}
8182
}
82-
counts
83+
// Convert the map to a vector of (index, value) tuples.
84+
// Sorting is good practice for some sparse matrix formats.
85+
let mut features: Vec<(usize, f32)> = sequence_kmer_counts.into_iter().collect();
86+
features.sort_unstable_by_key(|&(idx, _)| idx);
87+
features
8388
})
8489
.collect()
8590
}
8691
}
8792

88-
/// Encodes class labels (strings) into integers.
93+
/// Encodes string labels into integer representations and vice-versa.
8994
#[derive(Serialize, Deserialize, Debug)]
9095
pub struct LabelEncoder {
91-
pub label_to_int: HashMap<String, usize>,
96+
pub label_to_int: FxHashMap<String, usize>,
9297
pub int_to_label: Vec<String>,
9398
}
9499

95100
impl LabelEncoder {
101+
/// Creates a new, empty `LabelEncoder`.
96102
pub fn new() -> Self {
97103
Self {
98-
label_to_int: HashMap::new(),
104+
label_to_int: FxHashMap::default(),
99105
int_to_label: Vec::new(),
100106
}
101107
}
102108

103-
/// Learns the label mapping from a collection of labels.
104-
pub fn fit<T: AsRef<str>>(&mut self, labels: &[T]) {
109+
/// Learns the mapping from a slice of string labels.
110+
pub fn fit<T: AsRef<str> + std::hash::Hash + std::cmp::Eq>(&mut self, labels: &[T]) {
105111
for label in labels {
106-
let label_str = label.as_ref();
107-
if !self.label_to_int.contains_key(label_str) {
112+
let label_str = label.as_ref().to_string();
113+
self.label_to_int.entry(label_str.clone()).or_insert_with(|| {
108114
let index = self.int_to_label.len();
109-
self.label_to_int.insert(label_str.to_string(), index);
110-
self.int_to_label.push(label_str.to_string());
111-
}
115+
self.int_to_label.push(label_str);
116+
index
117+
});
112118
}
113119
}
114120

115-
/// Transforms a collection of labels into their integer representations.
121+
/// Transforms a slice of string labels into their integer representations.
116122
pub fn transform<T: AsRef<str>>(&self, labels: &[T]) -> Vec<usize> {
117123
labels
118124
.iter()
119-
.map(|label| *self.label_to_int.get(label.as_ref()).unwrap())
125+
.map(|label| {
126+
*self.label_to_int.get(label.as_ref())
127+
.expect("Label not found in encoder. `fit` must be called first with all possible labels.")
128+
})
120129
.collect()
121130
}
122-
}
123-
124-
// --- Utility Functions ---
125-
126-
/// Generates a space-separated string of k-mers from a DNA sequence.
127-
///
128-
/// # Arguments
129-
/// * `sequence` - The input DNA sequence.
130-
/// * `k` - The k-mer size.
131-
///
132-
/// # Returns
133-
/// A `String` of k-mers. Returns an empty string if the sequence is shorter than `k`.
134-
pub fn kmerize(sequence: &str, k: usize) -> String {
135-
if sequence.len() < k {
136-
return String::new();
137-
}
138-
(0..=sequence.len() - k)
139-
.map(|i| &sequence[i..i + k])
140-
.collect::<Vec<&str>>()
141-
.join(" ")
142-
}
131+
}

0 commit comments

Comments
 (0)