Skip to content

Commit 6ddf9ae

Browse files
committed
Add AdamW Optimizer
1 parent a84b940 commit 6ddf9ae

File tree

3 files changed

+181
-2
lines changed

3 files changed

+181
-2
lines changed

DIRECTORY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@
237237
* [Negative Log Likelihood](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/negative_log_likelihood.rs)
238238
* Optimization
239239
* [Adam](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/optimization/adam.rs)
240+
* [AdamW](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/optimization/adamw.rs)
240241
* [Gradient Descent](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/optimization/gradient_descent.rs)
241242
* [Momentum](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/optimization/momentum.rs)
242243
* Math
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
//! # AdamW (Adam with decoupled weight decay) optimizer
2+
//!
3+
//! AdamW modifies the standard Adam optimizer by decoupling weight decay from the
4+
//! gradient update step. In standard Adam, weight decay is typically implemented
5+
//! by adding an L2 penalty to the loss, which interacts with the adaptive learning
6+
//! rates in a way that often results in suboptimal model convergence.
7+
//!
8+
//! AdamW explicitly decays the weights prior to the gradient update, restoring
9+
//! the original mathematical definition of weight decay and generally enabling
10+
//! better performance on complex models such as transformers.
11+
//!
12+
//! ## Resources:
13+
//! - Decoupled Weight Decay Regularization (by Ilya Loshchilov and Frank Hutter):
14+
//! - [https://arxiv.org/abs/1711.05101]
15+
//! - PyTorch AdamW optimizer:
16+
//! - [https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html]
17+
18+
/// State for the AdamW optimizer (Adam with decoupled weight decay).
#[allow(dead_code)]
pub struct AdamW {
    learning_rate: f64, // alpha: initial step size
    betas: (f64, f64),  // exponential decay rates for the moment estimates
    epsilon: f64,       // small constant preventing division by zero
    weight_decay: f64,  // decoupled weight-decay coefficient (the 'W' in AdamW)
    m: Vec<f64>,        // biased first moment estimate of the gradient
    v: Vec<f64>,        // biased second raw moment estimate of the gradient
    t: usize,           // time step: number of `step` calls performed so far
}

#[allow(dead_code)]
impl AdamW {
    /// Creates a new AdamW optimizer for `params_len` parameters.
    ///
    /// Any `None` hyperparameter falls back to the defaults used by
    /// PyTorch's `torch.optim.AdamW`: learning rate `1e-3`,
    /// betas `(0.9, 0.999)`, epsilon `1e-8`, weight decay `1e-2`.
    pub fn new(
        learning_rate: Option<f64>,
        betas: Option<(f64, f64)>,
        epsilon: Option<f64>,
        weight_decay: Option<f64>,
        params_len: usize,
    ) -> Self {
        AdamW {
            learning_rate: learning_rate.unwrap_or(1e-3),
            betas: betas.unwrap_or((0.9, 0.999)),
            epsilon: epsilon.unwrap_or(1e-8),
            weight_decay: weight_decay.unwrap_or(1e-2), // default weight decay scaling
            m: vec![0.0; params_len],
            v: vec![0.0; params_len],
            t: 0,
        }
    }

    /// Performs one AdamW update, modifying `params` in place.
    ///
    /// Weight decay is applied directly to the parameters *before* the
    /// adaptive gradient step (decoupled weight decay), then the standard
    /// bias-corrected Adam step follows.
    ///
    /// # Panics
    ///
    /// Panics if `params`, `gradients`, and the state created with
    /// `params_len` do not all have the same length. Checking up front
    /// guarantees `params` is never left partially updated.
    pub fn step(&mut self, params: &mut [f64], gradients: &[f64]) {
        assert_eq!(
            params.len(),
            gradients.len(),
            "Parameters and gradients must be identical sizes."
        );
        assert_eq!(
            params.len(),
            self.m.len(),
            "Optimizer state was initialized for a different parameter count."
        );
        self.t += 1;

        let (beta1, beta2) = self.betas;
        // Bias-correction denominators 1 - beta^t are loop-invariant; hoist them.
        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);
        let bias_correction2 = 1.0 - beta2.powi(self.t as i32);

        for ((param, &grad), (m, v)) in params
            .iter_mut()
            .zip(gradients)
            .zip(self.m.iter_mut().zip(self.v.iter_mut()))
        {
            // Apply decoupled weight decay (the 'W' in AdamW) inline:
            // shrink the weight directly instead of folding an L2 term
            // into the gradient.
            *param -= self.learning_rate * self.weight_decay * *param;

            // Update biased first and second raw moment estimates.
            *m = beta1 * *m + (1.0 - beta1) * grad;
            *v = beta2 * *v + (1.0 - beta2) * grad * grad;

            // Bias-corrected moment estimates.
            let m_hat = *m / bias_correction1;
            let v_hat = *v / bias_correction2;

            // Apply the standard Adam adaptive learning-rate step.
            *param -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
        }
    }
}
76+
77+
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults should match PyTorch's torch.optim.AdamW defaults.
    #[test]
    fn test_adamw_init_default_values() {
        let optimizer = AdamW::new(None, None, None, None, 1);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.weight_decay, 1e-2);
        assert_eq!(optimizer.m, vec![0.0; 1]);
        assert_eq!(optimizer.v, vec![0.0; 1]);
        assert_eq!(optimizer.t, 0);
    }

    // Every explicitly supplied hyperparameter must be stored verbatim.
    #[test]
    fn test_adamw_init_custom_values() {
        let optimizer = AdamW::new(Some(0.1), Some((0.8, 0.888)), Some(1e-4), Some(0.005), 3);

        assert_eq!(optimizer.learning_rate, 0.1);
        assert_eq!(optimizer.betas, (0.8, 0.888));
        assert_eq!(optimizer.epsilon, 1e-4);
        assert_eq!(optimizer.weight_decay, 0.005);
        assert_eq!(optimizer.m, vec![0.0; 3]);
        assert_eq!(optimizer.v, vec![0.0; 3]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adamw_step_default_params() {
        let gradients = vec![-1.0, 2.0, -3.0];
        let mut params = vec![0.5, -0.5, 0.0]; // non-zero starting params to test wd

        let mut optimizer = AdamW::new(None, None, None, None, 3);
        optimizer.step(&mut params, &gradients);

        // Hand-derived expectation for i = 0 (param = 0.5, grad = -1.0):
        //   weight decay: param = 0.5 - (0.001 * 0.01 * 0.5) = 0.5 - 0.000005 = 0.499995
        //   m = 0.9(0) + 0.1(-1.0) = -0.1
        //   v = 0.999(0) + 0.001(1.0) = 0.001
        //   m_hat = -0.1 / 0.1 = -1.0
        //   v_hat = 0.001 / 0.001 = 1.0
        //   param -= 0.001 * -1.0 / (1.0 + 1e-8)
        //   final param roughly 0.499995 + 0.001 = 0.50099499999
        // A negative gradient pushes the param up; a positive one pushes it down.
        assert!(params[0] > 0.5);
        assert!(params[1] < -0.5);
    }

    #[test]
    fn test_adamw_step_zero_gradients_with_weight_decay() {
        // If gradients are zero, params should strictly decay toward zero.
        let gradients = vec![0.0, 0.0];
        let mut params = vec![100.0, -100.0];

        // lr = 1.0 with weight_decay = 0.1 yields exactly 10% decay per step.
        let mut optimizer = AdamW::new(Some(1.0), None, None, Some(0.1), 2);
        optimizer.step(&mut params, &gradients);

        assert_eq!(params, vec![90.0, -90.0]); // 10% toward 0
        optimizer.step(&mut params, &gradients);
        assert_eq!(params, vec![81.0, -81.0]);
    }

    // Ignored by default: iterates up to 1000 steps, slower than the unit tests.
    #[ignore]
    #[test]
    fn test_adamw_step_iteratively_until_convergence() {
        const CONVERGENCE_THRESHOLD: f64 = 1e-4;
        let gradients = vec![1.0, 2.0, 3.0, 4.0];

        let mut optimizer = AdamW::new(Some(0.01), None, None, Some(1e-4), 4);
        let mut model_params = vec![5.0; 4];

        let mut updates_made = true;
        let mut loops = 0;

        while updates_made && loops < 1000 {
            let old_params = model_params.clone();
            optimizer.step(&mut model_params, &gradients);

            // Euclidean norm of this step's parameter change; convergence is
            // declared once it drops below the threshold.
            let mut diff = 0.0;
            for i in 0..model_params.len() {
                diff += (old_params[i] - model_params[i]).powi(2);
            }
            if diff.sqrt() < CONVERGENCE_THRESHOLD {
                updates_made = false;
            }
            loops += 1;
        }

        assert!(
            loops < 1000,
            "Optimizer failed to converge within 1000 epochs."
        );

        // Because the gradient is constantly pushing against it, AdamW will find an equilibrium point
        // balancing the gradient direction with the weight decay pressure.
        assert!(model_params[0] < 5.0);
    }
}
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod adam;
2-
mod gradient_descent;
3-
mod momentum;
2+
pub mod adamw;
3+
pub mod gradient_descent;
4+
pub mod momentum;
45

56
pub use self::adam::Adam;
67
pub use self::gradient_descent::gradient_descent;

0 commit comments

Comments
 (0)