Skip to content

Commit 394c852

Browse files
feat(api): add Predict method to MLP Struct
1 parent 679ee47 commit 394c852

1 file changed

Lines changed: 37 additions & 1 deletion

File tree

  • try1 (OOP Approach)/rust/src

try1 (OOP Approach)/rust/src/mlp.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ impl Mlp {
102102
}
103103
}
104104
}
105-
pub fn feed_forward(&mut self, inputs: &Vec<f32>) -> Vec<f32> {
105+
pub fn feed_forward(&mut self, inputs: &[f32]) -> Vec<f32> {
106106
self.reset_neurons_activations();
107107
if inputs.len() != self.input_layer_size as usize {
108108
panic!("Expected Input wasn't Received");
@@ -131,4 +131,40 @@ impl Mlp {
131131
}
132132
output
133133
}
134+
pub fn predict(&mut self, inputs: &[Vec<f32>], targets: &[Vec<f32>]) {
135+
let mut accuracy: f32 = 0.0;
136+
for (input_sample, target_sample) in inputs.iter().zip(targets.iter()) {
137+
println!("Inputs : ");
138+
for input_sample_feature in input_sample.iter() {
139+
print!("{:.2}, ", input_sample_feature);
140+
}
141+
142+
println!("Outputs : ");
143+
144+
//Get the highest output
145+
let outputs = self.feed_forward(input_sample);
146+
let mut max: f32 = 0.0;
147+
let mut max_index: u8 = 0;
148+
for (i, output) in outputs.iter().enumerate() {
149+
if *output > max {
150+
max = *output;
151+
max_index = i as u8;
152+
}
153+
}
154+
155+
// Show Results from feedfoward with color and target
156+
for (i, output) in outputs.iter().enumerate() {
157+
if i == max_index as usize && target_sample[i] == 1.0 {
158+
cprintln!("<green>[{}] : {:.2} => {}</>", i, output, target_sample[i]);
159+
accuracy += 100.0 / inputs.len() as f32;
160+
} else if i == max_index as usize && target_sample[i] != 1.0 {
161+
cprintln!("<yellow>[{}] : {:.2} => {}</>", i, output, target_sample[i]);
162+
} else if i != max_index as usize {
163+
cprintln!("<red>[{}] : {:.2} => {}</>", i, output, target_sample[i]);
164+
}
165+
}
166+
}
167+
println!("Accuracy : {:.2}%", accuracy);
168+
}
169+
// TODO: Implement outher MLP class functions
134170
}

0 commit comments

Comments
 (0)