Skip to content

Commit 33aae10

Browse files
committed
Added the momentum optimizer (momentum.rs) to the machine learning module
1 parent a8491ae commit 33aae10

2 files changed

Lines changed: 145 additions & 0 deletions

File tree

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
mod adam;
22
mod gradient_descent;
3+
mod momentum;
34

45
pub use self::adam::Adam;
56
pub use self::gradient_descent::gradient_descent;
7+
pub use self::momentum::momentum;
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/// Momentum Optimization
2+
///
3+
/// Momentum is an extension of gradient descent that accelerates convergence by accumulating
4+
/// a velocity vector in directions of persistent reduction in the objective function.
5+
/// This helps the optimizer navigate ravines and avoid getting stuck in local minima.
6+
///
7+
/// The algorithm maintains a velocity vector that accumulates exponentially decaying moving
8+
/// averages of past gradients. This allows the optimizer to build up speed in consistent
9+
/// directions while dampening oscillations.
10+
///
11+
/// The update equations are:
12+
/// velocity_{k+1} = beta * velocity_k + gradient_of_function(x_k)
13+
/// x_{k+1} = x_k - learning_rate * velocity_{k+1}
14+
///
15+
/// where beta (typically 0.9) controls how much past gradients influence the current update.
16+
///
17+
/// # Arguments
18+
///
19+
/// * `derivative_fn` - The function that calculates the gradient of the objective function at a given point.
20+
/// * `x` - The initial parameter vector to be optimized.
21+
/// * `learning_rate` - Step size for each iteration.
22+
/// * `beta` - Momentum coefficient (typically 0.9). Higher values give more weight to past gradients.
23+
/// * `num_iterations` - The number of iterations to run the optimization.
24+
///
25+
/// # Returns
26+
///
27+
/// A reference to the optimized parameter vector `x`.
28+
pub fn momentum(
29+
derivative: impl Fn(&[f64]) -> Vec<f64>,
30+
x: &mut Vec<f64>,
31+
learning_rate: f64,
32+
beta: f64,
33+
num_iterations: i32,
34+
) -> &mut Vec<f64> {
35+
// Initialize velocity vector to zero
36+
let mut velocity: Vec<f64> = vec![0.0; x.len()];
37+
38+
for _ in 0..num_iterations {
39+
let gradient = derivative(x);
40+
41+
// Update velocity and parameters
42+
for ((x_k, vel), grad) in x.iter_mut().zip(velocity.iter_mut()).zip(gradient.iter()) {
43+
*vel = beta * *vel + grad;
44+
*x_k -= learning_rate * *vel;
45+
}
46+
}
47+
x
48+
}
49+
50+
#[cfg(test)]
mod test {
    use super::*;

    /// Gradient of f(x) = sum(x_i^2), i.e. 2x — shared by every test below.
    fn derivative_of_square(params: &[f64]) -> Vec<f64> {
        params.iter().map(|x| 2.0 * x).collect()
    }

    #[test]
    fn test_momentum_optimized() {
        // With enough iterations the optimizer should land on the minimum at the origin.
        let mut params: Vec<f64> = vec![5.0, 6.0];
        let result = momentum(derivative_of_square, &mut params, 0.01, 0.9, 1000);

        let expected = [0.0, 0.0];
        let tolerance = 1e-6;
        for (got, want) in result.iter().zip(expected.iter()) {
            assert!((got - want).abs() < tolerance);
        }
    }

    #[test]
    fn test_momentum_unoptimized() {
        // Far too few iterations: the result must still be clearly away from the minimum.
        let mut params: Vec<f64> = vec![5.0, 6.0];
        let result = momentum(derivative_of_square, &mut params, 0.01, 0.9, 10);

        let expected = [0.0, 0.0];
        let tolerance = 1e-6;
        for (got, want) in result.iter().zip(expected.iter()) {
            assert!((got - want).abs() >= tolerance);
        }
    }

    #[test]
    fn test_momentum_faster_than_gd() {
        // Momentum should make more progress than plain gradient descent
        // given an identical start point, step size, and iteration budget.
        let mut with_momentum: Vec<f64> = vec![5.0, 6.0];
        let mut plain_gd: Vec<f64> = vec![5.0, 6.0];
        let (lr, beta, iters): (f64, f64, i32) = (0.01, 0.9, 50);

        momentum(derivative_of_square, &mut with_momentum, lr, beta, iters);

        // Reference: vanilla gradient descent, no velocity term.
        for _ in 0..iters {
            let grad = derivative_of_square(&plain_gd);
            plain_gd
                .iter_mut()
                .zip(grad.iter())
                .for_each(|(p, g)| *p -= lr * g);
        }

        // Compare squared distances from the minimum at the origin.
        let sq_norm = |v: &[f64]| v.iter().map(|x| x * x).sum::<f64>();
        assert!(sq_norm(&with_momentum) < sq_norm(&plain_gd));
    }
}

0 commit comments

Comments
 (0)