Skip to content

Commit 3d0f91c

Browse files
feat(api): add initial MLP Logic with more to come, fix(types): fix types not being right
1 parent ebddd1d commit 3d0f91c

1 file changed

Lines changed: 84 additions & 4 deletions

File tree

  • try1 (OOP Approach)/rust/src

try1 (OOP Approach)/rust/src/mlp.rs

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use color_print::{cformat, cprintln};
12
use rand::Rng;
23

34
fn random_float(rand_range: u8) -> f32 {
@@ -6,13 +7,16 @@ fn random_float(rand_range: u8) -> f32 {
67

78
(random_val * (rand_range as f32) * 2.0) - 1.0
89
}
10+
fn sigmoid(x: f32) -> f32 {
11+
1.0 / (1.0 + (-x).exp())
12+
}
913
struct Neuron {
1014
pub value: f32,
1115
pub bias: f32,
1216
pub weights: Vec<f32>,
1317
}
1418
impl Neuron {
15-
fn new(prev_layer_neurons_count: u8) -> Neuron {
19+
fn new(prev_layer_neurons_count: u16) -> Neuron {
1620
let rand_range: u8 = 1;
1721
let mut weights: Vec<f32> = Vec::new();
1822
for _ in 0..prev_layer_neurons_count {
@@ -28,15 +32,19 @@ impl Neuron {
2832
}
2933
pub struct Layer {
3034
neurons: Vec<Neuron>,
35+
prev_layer_neurons_count: u16,
3136
}
3237
impl Layer {
33-
pub fn new(size: u16, prev_layer_neurons_count: u8) -> Layer {
38+
pub fn new(neurons_count: u16, prev_layer_neurons_count: u16) -> Layer {
3439
let mut neurons: Vec<Neuron> = Vec::new();
35-
for _ in 0..size {
40+
for _ in 0..neurons_count {
3641
neurons.push(Neuron::new(prev_layer_neurons_count));
3742
}
3843

39-
Layer { neurons }
44+
Layer {
45+
neurons,
46+
prev_layer_neurons_count,
47+
}
4048
}
4149
pub fn show_neurons(&self) {
4250
for (i, neuron) in self.neurons.iter().enumerate() {
@@ -52,3 +60,75 @@ impl Layer {
5260
}
5361
}
5462
// TODO: Add remaining Structs like MLP,
63+
pub struct Mlp {
64+
hid_out_layers: Vec<Layer>,
65+
input_layer_size: u16,
66+
lrate: f32,
67+
}
68+
impl Mlp {
69+
pub fn new(input_layer_size: u16, hid_out_layers_sizes: &Vec<u16>, lrate: f32) -> Mlp {
70+
let mut hid_out_layers: Vec<Layer> = Vec::new();
71+
let mut prev_layer_neurons_count: u16 = input_layer_size;
72+
for size in hid_out_layers_sizes.iter() {
73+
hid_out_layers.push(Layer::new(*size, prev_layer_neurons_count));
74+
prev_layer_neurons_count = *size;
75+
}
76+
Mlp {
77+
hid_out_layers,
78+
input_layer_size,
79+
lrate,
80+
}
81+
}
82+
pub fn describe(&self) {
83+
println!();
84+
cprintln!("<green>+------------------------------------+</>");
85+
cprintln!("<green> Multi Layer Perceptron </>");
86+
cprintln!("<green>+------------------------------------+</>");
87+
88+
// +1 to also consider the input layer
89+
println!("Layer Count : {}", self.hid_out_layers.len() + 1);
90+
cprintln!("<cyan>Layer Sizes: </>");
91+
92+
print!("{} | ", self.input_layer_size);
93+
for layer in self.hid_out_layers.iter() {
94+
print!("{} | ", layer.neurons.len());
95+
}
96+
println!();
97+
}
98+
pub fn reset_neurons_activations(&mut self) {
99+
for layer in self.hid_out_layers.iter_mut() {
100+
for neuron in layer.neurons.iter_mut() {
101+
neuron.value = 0.0;
102+
}
103+
}
104+
}
105+
pub fn feed_forward(&mut self, inputs: &Vec<f32>) -> Vec<f32> {
106+
self.reset_neurons_activations();
107+
if inputs.len() != self.input_layer_size as usize {
108+
panic!("Expected Input wasn't Received");
109+
}
110+
let mut prev_layer = &mut Layer::new(self.input_layer_size, 0);
111+
for (i, input) in inputs.iter().enumerate() {
112+
prev_layer.neurons[i].value = *input;
113+
}
114+
// For traversing each layer
115+
for layer in self.hid_out_layers.iter_mut() {
116+
// for traversing each neruons of the layer
117+
for neuron in layer.neurons.iter_mut() {
118+
let mut weighted_sum: f32 = 0.0;
119+
for (i, weight) in neuron.weights.iter().enumerate() {
120+
weighted_sum += prev_layer.neurons[i].value * weight;
121+
}
122+
neuron.value += sigmoid(weighted_sum) + neuron.bias;
123+
}
124+
prev_layer = layer;
125+
}
126+
127+
// For Returning Output
128+
let mut output: Vec<f32> = Vec::new();
129+
for neuron in prev_layer.neurons.iter() {
130+
output.push(neuron.value);
131+
}
132+
output
133+
}
134+
}

0 commit comments

Comments
 (0)