Commit d2e1907

add naive-bayes
1 parent 5a4e21f commit d2e1907

3 files changed

Lines changed: 294 additions & 0 deletions

DIRECTORY.md

Lines changed: 1 addition & 0 deletions
@@ -207,6 +207,7 @@
 * [K-Nearest Neighbors](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_nearest_neighbors.rs)
 * [Linear Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/linear_regression.rs)
 * [Logistic Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/logistic_regression.rs)
+* [Naive Bayes](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/naive_bayes.rs)
 * Loss Function
 * [Average Margin Ranking Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/average_margin_ranking_loss.rs)
 * [Hinge Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/hinge_loss.rs)

src/machine_learning/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@ mod k_nearest_neighbors;
 mod linear_regression;
 mod logistic_regression;
 mod loss_function;
+mod naive_bayes;
 mod optimization;

 pub use self::cholesky::cholesky;
@@ -18,5 +19,6 @@ pub use self::loss_function::kld_loss;
 pub use self::loss_function::mae_loss;
 pub use self::loss_function::mse_loss;
 pub use self::loss_function::neg_log_likelihood;
+pub use self::naive_bayes::naive_bayes;
 pub use self::optimization::gradient_descent;
 pub use self::optimization::Adam;

src/machine_learning/naive_bayes.rs

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@
/// Naive Bayes classifier for classification tasks.
/// This implementation uses Gaussian Naive Bayes, which assumes that
/// features follow a normal (Gaussian) distribution.
/// The algorithm calculates class priors and feature statistics (mean and variance)
/// for each class, then uses Bayes' theorem to predict class probabilities.
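///
/// Prediction picks the class with the highest log-posterior, i.e. the label maximizing
///   ln(prior) + Σᵢ ln N(xᵢ | meanᵢ, varianceᵢ)
/// over the per-class statistics (the rule implemented in predict_naive_bayes below).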

pub struct ClassStatistics {
    pub class_label: f64,
    pub prior: f64,
    pub feature_means: Vec<f64>,
    pub feature_variances: Vec<f64>,
}

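// Example (hand-computed from the functions below): fitting the two class-0
// samples (vec![1.0, 1.0], 0.0) and (vec![1.1, 1.0], 0.0) out of four total
// training samples gives prior = 0.5, feature_means = [1.05, 1.0], and
// feature_variances = [0.0025, 1e-9] (the second variance is floored at epsilon).
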
fn calculate_class_statistics(
    training_data: &[(Vec<f64>, f64)],
    class_label: f64,
    num_features: usize,
) -> Option<ClassStatistics> {
    // Collect the samples belonging to this class (labels compared with a tolerance).
    let class_samples: Vec<&(Vec<f64>, f64)> = training_data
        .iter()
        .filter(|(_, label)| (*label - class_label).abs() < 1e-10)
        .collect();

    if class_samples.is_empty() {
        return None;
    }

    let prior = class_samples.len() as f64 / training_data.len() as f64;

    let mut feature_means = vec![0.0; num_features];
    let mut feature_variances = vec![0.0; num_features];

    // Calculate means
    for (features, _) in &class_samples {
        for (i, &feature) in features.iter().enumerate() {
            if i < num_features {
                feature_means[i] += feature;
            }
        }
    }

    let n = class_samples.len() as f64;
    for mean in &mut feature_means {
        *mean /= n;
    }

    // Calculate variances
    for (features, _) in &class_samples {
        for (i, &feature) in features.iter().enumerate() {
            if i < num_features {
                let diff = feature - feature_means[i];
                feature_variances[i] += diff * diff;
            }
        }
    }

    // Floor each variance at epsilon so a constant feature cannot produce a
    // zero variance, which would make the Gaussian density degenerate.
    let epsilon = 1e-9;
    for variance in &mut feature_variances {
        *variance = (*variance / n).max(epsilon);
    }

    Some(ClassStatistics {
        class_label,
        prior,
        feature_means,
        feature_variances,
    })
}

/// Log of the univariate normal density,
///   ln N(x | mean, variance) = -(1/2) ln(2π · variance) - (x - mean)² / (2 · variance),
/// evaluated in log space so that products of many small densities become sums
/// and do not underflow.
fn gaussian_log_pdf(x: f64, mean: f64, variance: f64) -> f64 {
    let diff = x - mean;
    let exponent_term = -(diff * diff) / (2.0 * variance);
    let log_coefficient = -0.5 * (2.0 * std::f64::consts::PI * variance).ln();
    log_coefficient + exponent_term
}

pub fn train_naive_bayes(training_data: Vec<(Vec<f64>, f64)>) -> Option<Vec<ClassStatistics>> {
    if training_data.is_empty() {
        return None;
    }

    let num_features = training_data[0].0.len();
    if num_features == 0 {
        return None;
    }

    // Verify all samples have the same number of features
    if !training_data
        .iter()
        .all(|(features, _)| features.len() == num_features)
    {
        return None;
    }

    // Get unique class labels
    let mut unique_classes = Vec::new();
    for (_, label) in &training_data {
        if !unique_classes
            .iter()
            .any(|&c: &f64| (c - *label).abs() < 1e-10)
        {
            unique_classes.push(*label);
        }
    }

    let mut class_stats = Vec::new();

    for class_label in unique_classes {
        if let Some(stats) =
            calculate_class_statistics(&training_data, class_label, num_features)
        {
            class_stats.push(stats);
        }
    }

    if class_stats.is_empty() {
        return None;
    }

    Some(class_stats)
}

pub fn predict_naive_bayes(model: &[ClassStatistics], test_point: &[f64]) -> Option<f64> {
    if model.is_empty() || test_point.is_empty() {
        return None;
    }

    // Get number of features from the first class statistics
    let num_features = model[0].feature_means.len();
    if test_point.len() != num_features {
        return None;
    }

    let mut best_class = None;
    let mut best_log_prob = f64::NEG_INFINITY;

    for stats in model {
        // Calculate log probability to avoid underflow
        let mut log_prob = stats.prior.ln();

        for (i, &feature) in test_point.iter().enumerate() {
            if i < stats.feature_means.len() && i < stats.feature_variances.len() {
                // Use log PDF directly to avoid numerical underflow
                log_prob +=
                    gaussian_log_pdf(feature, stats.feature_means[i], stats.feature_variances[i]);
            }
        }

        if log_prob > best_log_prob {
            best_log_prob = log_prob;
            best_class = Some(stats.class_label);
        }
    }

    best_class
}

pub fn naive_bayes(training_data: Vec<(Vec<f64>, f64)>, test_point: Vec<f64>) -> Option<f64> {
    let model = train_naive_bayes(training_data)?;
    predict_naive_bayes(&model, &test_point)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_naive_bayes_simple_classification() {
        let training_data = vec![
            (vec![1.0, 1.0], 0.0),
            (vec![1.1, 1.0], 0.0),
            (vec![1.0, 1.1], 0.0),
            (vec![5.0, 5.0], 1.0),
            (vec![5.1, 5.0], 1.0),
            (vec![5.0, 5.1], 1.0),
        ];

        // Test point closer to class 0
        let test_point = vec![1.05, 1.05];
        let result = naive_bayes(training_data.clone(), test_point);
        assert_eq!(result, Some(0.0));

        // Test point closer to class 1
        let test_point = vec![5.05, 5.05];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, Some(1.0));
    }

    #[test]
    fn test_naive_bayes_one_dimensional() {
        let training_data = vec![
            (vec![1.0], 0.0),
            (vec![1.1], 0.0),
            (vec![1.2], 0.0),
            (vec![5.0], 1.0),
            (vec![5.1], 1.0),
            (vec![5.2], 1.0),
        ];

        let test_point = vec![1.15];
        let result = naive_bayes(training_data.clone(), test_point);
        assert_eq!(result, Some(0.0));

        let test_point = vec![5.15];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, Some(1.0));
    }

    #[test]
    fn test_naive_bayes_empty_training_data() {
        let training_data = vec![];
        let test_point = vec![1.0, 2.0];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, None);
    }

    #[test]
    fn test_naive_bayes_empty_test_point() {
        let training_data = vec![(vec![1.0, 2.0], 0.0)];
        let test_point = vec![];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, None);
    }

    #[test]
    fn test_naive_bayes_dimension_mismatch() {
        let training_data = vec![(vec![1.0, 2.0], 0.0), (vec![3.0, 4.0], 1.0)];
        let test_point = vec![1.0]; // Wrong dimension
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, None);
    }

    #[test]
    fn test_naive_bayes_inconsistent_feature_dimensions() {
        let training_data = vec![
            (vec![1.0, 2.0], 0.0),
            (vec![3.0], 1.0), // Different dimension
        ];
        let test_point = vec![1.0, 2.0];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, None);
    }

    #[test]
    fn test_naive_bayes_multiple_classes() {
        let training_data = vec![
            (vec![1.0, 1.0], 0.0),
            (vec![1.1, 1.0], 0.0),
            (vec![5.0, 5.0], 1.0),
            (vec![5.1, 5.0], 1.0),
            (vec![9.0, 9.0], 2.0),
            (vec![9.1, 9.0], 2.0),
        ];

        let test_point = vec![1.05, 1.05];
        let result = naive_bayes(training_data.clone(), test_point);
        assert_eq!(result, Some(0.0));

        let test_point = vec![5.05, 5.05];
        let result = naive_bayes(training_data.clone(), test_point);
        assert_eq!(result, Some(1.0));

        let test_point = vec![9.05, 9.05];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, Some(2.0));
    }

    #[test]
    fn test_train_and_predict_separately() {
        let training_data = vec![
            (vec![1.0, 1.0], 0.0),
            (vec![1.1, 1.0], 0.0),
            (vec![5.0, 5.0], 1.0),
            (vec![5.1, 5.0], 1.0),
        ];

        let model = train_naive_bayes(training_data);
        assert!(model.is_some());

        let model = model.unwrap();
        assert_eq!(model.len(), 2);

        let test_point = vec![1.05, 1.05];
        let result = predict_naive_bayes(&model, &test_point);
        assert_eq!(result, Some(0.0));
    }
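
    // A small additional check (a sketch, not part of the original commit):
    // with a single class, the fitted prior is 1.0 and the per-feature mean and
    // population variance match the hand-computed values.
    #[test]
    fn test_fitted_statistics_sketch() {
        let training_data = vec![(vec![0.0], 0.0), (vec![2.0], 0.0)];
        let model = train_naive_bayes(training_data).unwrap();
        assert_eq!(model.len(), 1);
        assert!((model[0].prior - 1.0).abs() < 1e-12);
        assert!((model[0].feature_means[0] - 1.0).abs() < 1e-12);
        // Population variance: ((0 - 1)^2 + (2 - 1)^2) / 2 = 1.0
        assert!((model[0].feature_variances[0] - 1.0).abs() < 1e-12);
    }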
}
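
With the re-export added in src/machine_learning/mod.rs, downstream code can call the classifier directly. A minimal caller sketch; the crate name `the_algorithms` and a `pub mod machine_learning` in lib.rs are assumptions, not shown in this commit:

// Hypothetical caller; crate name is an assumption.
use the_algorithms::machine_learning::naive_bayes;

fn main() {
    // Two well-separated classes; the test point lies near class 0.0.
    let training_data = vec![
        (vec![1.0, 1.0], 0.0),
        (vec![1.1, 1.0], 0.0),
        (vec![5.0, 5.0], 1.0),
        (vec![5.1, 5.0], 1.0),
    ];
    assert_eq!(naive_bayes(training_data, vec![1.05, 1.05]), Some(0.0));
}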
