Commit 3bccfa8

feat: Add decoupled weight decay (AdamW) to Adam optimizer (#1037)
1 parent e08f5a5 commit 3bccfa8

1 file changed

Lines changed: 184 additions & 27 deletions


src/machine_learning/optimization/adam.rs

@@ -5,12 +5,19 @@
 //! learning problems. Boasting memory-efficient fast convergence rates, it sets and iteratively
 //! updates learning rates individually for each model parameter based on the gradient history.
 //!
+//! Setting `weight_decay > 0.0` enables the AdamW variant (Loshchilov & Hutter, 2019), which
+//! applies weight decay directly to the parameters rather than folding it into the gradients.
+//! This keeps the decay rate constant and independent of the gradient history — the key flaw
+//! that AdamW corrects over naive L2 regularization inside Adam. With `weight_decay = 0.0`
+//! (the default), the two algorithms are identical.
+//!
 //! ## Algorithm:
 //!
 //! Given:
 //! - α is the learning rate
 //! - (β_1, β_2) are the exponential decay rates for moment estimates
 //! - ϵ is any small value to prevent division by zero
+//! - λ is the weight decay coefficient (0.0 for standard Adam, > 0.0 for AdamW)
 //! - g_t are the gradients at time step t
 //! - m_t are the biased first moment estimates of the gradient at time step t
 //! - v_t are the biased second raw moment estimates of the gradient at time step t
@@ -28,20 +35,25 @@
 //! while θ_t not converged do
 //!     m_t = β_1 * m_{t−1} + (1 − β_1) * g_t
 //!     v_t = β_2 * v_{t−1} + (1 − β_2) * g_t^2
-//!     m_hat_t = m_t / 1 - β_1^t
-//!     v_hat_t = v_t / 1 - β_2^t
-//!     θ_t = θ_{t-1} − α * m_hat_t / (sqrt(v_hat_t) + ϵ)
+//!     m_hat_t = m_t / (1 - β_1^t)
+//!     v_hat_t = v_t / (1 - β_2^t)
+//!     θ_t = θ_{t-1} − α * (m_hat_t / (sqrt(v_hat_t) + ϵ) + λ * θ_{t-1})
 //!
 //! ## Resources:
 //! - Adam: A Method for Stochastic Optimization (by Diederik P. Kingma and Jimmy Ba):
 //!   - [https://arxiv.org/abs/1412.6980]
+//! - Decoupled Weight Decay Regularization (by Ilya Loshchilov and Frank Hutter):
+//!   - [https://arxiv.org/abs/1711.05101]
 //! - PyTorch Adam optimizer:
-//!   - [https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam]
+//!   - [https://pytorch.org/docs/stable/generated/torch.optim.Adam.html]
+//! - PyTorch AdamW optimizer:
+//!   - [https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html]
 //!
 pub struct Adam {
     learning_rate: f64, // alpha: initial step size for iterative optimization
     betas: (f64, f64),  // betas: exponential decay rates for moment estimates
     epsilon: f64,       // epsilon: prevent division by zero
+    weight_decay: f64,  // lambda: decoupled weight decay coefficient (0.0 = standard Adam)
     m: Vec<f64>,        // m: biased first moment estimate of the gradient vector
     v: Vec<f64>,        // v: biased second raw moment estimate of the gradient vector
     t: usize,           // t: time step
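
The doc comment above describes the decoupling in prose; the following standalone sketch (not part of the commit, all values illustrative) works one update step for a single parameter both ways, starting from zero moments, so the only difference is where the decay coefficient λ enters:

fn main() {
    let (lr, beta1, beta2, eps, lambda) = (0.001_f64, 0.9_f64, 0.999_f64, 1e-8_f64, 0.01_f64);
    let (theta, g, t) = (0.5_f64, 2.0_f64, 1_i32);

    // Coupled L2 inside Adam: the decay is folded into the gradient, so the
    // adaptive denominator later rescales it along with everything else.
    let g_l2 = g + lambda * theta;
    let m_hat = ((1.0 - beta1) * g_l2) / (1.0 - beta1.powi(t)); // m_0 = 0
    let v_hat = ((1.0 - beta2) * g_l2 * g_l2) / (1.0 - beta2.powi(t)); // v_0 = 0
    let theta_l2 = theta - lr * m_hat / (v_hat.sqrt() + eps);

    // Decoupled decay (AdamW): the moments see only the raw gradient; the
    // decay acts on the parameter directly, untouched by the adaptive scaling.
    let m_hat = ((1.0 - beta1) * g) / (1.0 - beta1.powi(t));
    let v_hat = ((1.0 - beta2) * g * g) / (1.0 - beta2.powi(t));
    let theta_w = theta - lr * m_hat / (v_hat.sqrt() + eps) - lr * lambda * theta;

    println!("coupled L2: {theta_l2:.6}, decoupled: {theta_w:.6}");
}
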
@@ -52,20 +64,38 @@ impl Adam {
         learning_rate: Option<f64>,
         betas: Option<(f64, f64)>,
         epsilon: Option<f64>,
+        weight_decay: Option<f64>,
         params_len: usize,
     ) -> Self {
         Adam {
             learning_rate: learning_rate.unwrap_or(1e-3), // typical good default lr
             betas: betas.unwrap_or((0.9, 0.999)),         // typical good default decay rates
             epsilon: epsilon.unwrap_or(1e-8),             // typical good default epsilon
+            weight_decay: weight_decay.unwrap_or(0.0),    // 0.0 = standard Adam, > 0.0 = AdamW
             m: vec![0.0; params_len], // first moment vector elements all initialized to zero
             v: vec![0.0; params_len], // second moment vector elements all initialized to zero
             t: 0,                     // time step initialized to zero
         }
     }

-    pub fn step(&mut self, gradients: &[f64]) -> Vec<f64> {
-        let mut model_params = vec![0.0; gradients.len()];
+    /// Computes one update step.
+    ///
+    /// `params` holds the current parameter values θ_{t-1}. When `weight_decay`
+    /// is `0.0` the update is standard Adam; any positive value applies the AdamW
+    /// decoupled decay term `λ * θ_{t-1}` directly to the parameters, independent
+    /// of the adaptive scaling.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `gradients` and `params` have different lengths.
+    pub fn step(&mut self, gradients: &[f64], params: &[f64]) -> Vec<f64> {
+        assert_eq!(
+            gradients.len(),
+            params.len(),
+            "gradients and params must have the same length"
+        );
+
+        let mut updated_params = vec![0.0; params.len()];
         self.t += 1;

         for i in 0..gradients.len() {
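
Taken together, the two hunks above define the full public API of this change. A usage sketch with arbitrary values, assuming `Adam` is already in scope (the exact `use` path depends on how the crate exposes this module):

fn main() {
    let mut optimizer = Adam::new(
        Some(1e-3), // learning_rate
        None,       // betas: defaults to (0.9, 0.999)
        None,       // epsilon: defaults to 1e-8
        Some(0.01), // weight_decay: any positive value selects AdamW
        3,          // params_len
    );

    let params = vec![0.5, -0.5, 1.0];
    let gradients = vec![1.0, -2.0, 3.0];

    // `step` reads θ_{t-1} from `params` and returns θ_t; `params` is not mutated.
    let updated = optimizer.step(&gradients, &params);
    println!("{updated:?}");
}
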
@@ -77,83 +107,111 @@ impl Adam {
             let m_hat = self.m[i] / (1.0 - self.betas.0.powi(self.t as i32));
             let v_hat = self.v[i] / (1.0 - self.betas.1.powi(self.t as i32));

-            // update model parameters
-            model_params[i] -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
+            // Adaptive gradient step — preserves the original (lr * m_hat) / denom
+            // operator order so floating-point results are identical to standard Adam
+            // when weight_decay = 0.0. The decoupled decay term is added separately
+            // so it does not interact with the adaptive scaling.
+            updated_params[i] = params[i]
+                - self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon)
+                - self.learning_rate * self.weight_decay * params[i];
         }
-        model_params // return updated model parameters
+        updated_params // return updated model parameters
     }
 }

 #[cfg(test)]
 mod tests {
     use super::*;

+    // ── Initialisation ────────────────────────────────────────────────────────
+
     #[test]
     fn test_adam_init_default_values() {
-        let optimizer = Adam::new(None, None, None, 1);
+        let optimizer = Adam::new(None, None, None, None, 1);

         assert_eq!(optimizer.learning_rate, 0.001);
         assert_eq!(optimizer.betas, (0.9, 0.999));
         assert_eq!(optimizer.epsilon, 1e-8);
+        assert_eq!(optimizer.weight_decay, 0.0);
         assert_eq!(optimizer.m, vec![0.0; 1]);
         assert_eq!(optimizer.v, vec![0.0; 1]);
         assert_eq!(optimizer.t, 0);
     }

     #[test]
     fn test_adam_init_custom_lr_value() {
-        let optimizer = Adam::new(Some(0.9), None, None, 2);
+        let optimizer = Adam::new(Some(0.9), None, None, None, 2);

         assert_eq!(optimizer.learning_rate, 0.9);
         assert_eq!(optimizer.betas, (0.9, 0.999));
         assert_eq!(optimizer.epsilon, 1e-8);
+        assert_eq!(optimizer.weight_decay, 0.0);
         assert_eq!(optimizer.m, vec![0.0; 2]);
         assert_eq!(optimizer.v, vec![0.0; 2]);
         assert_eq!(optimizer.t, 0);
     }

     #[test]
     fn test_adam_init_custom_betas_value() {
-        let optimizer = Adam::new(None, Some((0.8, 0.899)), None, 3);
+        let optimizer = Adam::new(None, Some((0.8, 0.899)), None, None, 3);

         assert_eq!(optimizer.learning_rate, 0.001);
         assert_eq!(optimizer.betas, (0.8, 0.899));
         assert_eq!(optimizer.epsilon, 1e-8);
+        assert_eq!(optimizer.weight_decay, 0.0);
         assert_eq!(optimizer.m, vec![0.0; 3]);
         assert_eq!(optimizer.v, vec![0.0; 3]);
         assert_eq!(optimizer.t, 0);
     }

     #[test]
     fn test_adam_init_custom_epsilon_value() {
-        let optimizer = Adam::new(None, None, Some(1e-10), 4);
+        let optimizer = Adam::new(None, None, Some(1e-10), None, 4);

         assert_eq!(optimizer.learning_rate, 0.001);
         assert_eq!(optimizer.betas, (0.9, 0.999));
         assert_eq!(optimizer.epsilon, 1e-10);
+        assert_eq!(optimizer.weight_decay, 0.0);
         assert_eq!(optimizer.m, vec![0.0; 4]);
         assert_eq!(optimizer.v, vec![0.0; 4]);
         assert_eq!(optimizer.t, 0);
     }

+    #[test]
+    fn test_adam_init_custom_weight_decay_value() {
+        let optimizer = Adam::new(None, None, None, Some(0.1), 3);
+
+        assert_eq!(optimizer.learning_rate, 0.001);
+        assert_eq!(optimizer.betas, (0.9, 0.999));
+        assert_eq!(optimizer.epsilon, 1e-8);
+        assert_eq!(optimizer.weight_decay, 0.1);
+        assert_eq!(optimizer.m, vec![0.0; 3]);
+        assert_eq!(optimizer.v, vec![0.0; 3]);
+        assert_eq!(optimizer.t, 0);
+    }
+
     #[test]
     fn test_adam_init_all_custom_values() {
-        let optimizer = Adam::new(Some(1.0), Some((0.001, 0.099)), Some(1e-1), 5);
+        let optimizer = Adam::new(Some(1.0), Some((0.001, 0.099)), Some(1e-1), Some(0.05), 5);

         assert_eq!(optimizer.learning_rate, 1.0);
         assert_eq!(optimizer.betas, (0.001, 0.099));
         assert_eq!(optimizer.epsilon, 1e-1);
+        assert_eq!(optimizer.weight_decay, 0.05);
         assert_eq!(optimizer.m, vec![0.0; 5]);
         assert_eq!(optimizer.v, vec![0.0; 5]);
         assert_eq!(optimizer.t, 0);
     }

+    // ── Step: standard Adam (weight_decay = 0.0) ──────────────────────────────
+
     #[test]
     fn test_adam_step_default_params() {
         let gradients = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];
+        let params = vec![0.0; 8];

-        let mut optimizer = Adam::new(None, None, None, 8);
-        let updated_params = optimizer.step(&gradients);
+        let mut optimizer = Adam::new(None, None, None, None, 8);
+        let updated_params = optimizer.step(&gradients, &params);

         assert_eq!(
             updated_params,
@@ -173,9 +231,10 @@ mod tests {
     #[test]
     fn test_adam_step_custom_params() {
         let gradients = vec![9.0, -8.0, 7.0, -6.0, 5.0, -4.0, 3.0, -2.0, 1.0];
+        let params = vec![0.0; 9];

-        let mut optimizer = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), 9);
-        let updated_params = optimizer.step(&gradients);
+        let mut optimizer = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), None, 9);
+        let updated_params = optimizer.step(&gradients, &params);

         assert_eq!(
             updated_params,
@@ -195,24 +254,93 @@ mod tests {

     #[test]
     fn test_adam_step_empty_gradients_array() {
-        let gradients = vec![];
+        let gradients: Vec<f64> = vec![];
+        let params: Vec<f64> = vec![];

-        let mut optimizer = Adam::new(None, None, None, 0);
-        let updated_params = optimizer.step(&gradients);
+        let mut optimizer = Adam::new(None, None, None, None, 0);
+        let updated_params = optimizer.step(&gradients, &params);

         assert_eq!(updated_params, vec![]);
     }

+    // ── Step: AdamW (weight_decay > 0.0) ─────────────────────────────────────
+
+    #[test]
+    fn test_adamw_step_nonzero_params_applies_decay() {
+        // When params are non-zero and weight_decay > 0.0, the decay term must pull
+        // every parameter strictly closer to zero than the plain adaptive step would.
+        // Comparing against a no-decay run avoids replicating the internal floating
+        // point computation path and tests the property that actually matters.
+        let gradients = vec![1.0, -2.0, 3.0];
+        let params = vec![0.5, -0.5, 1.0];
+
+        let mut with_decay = Adam::new(None, None, None, Some(0.01), 3);
+        let decayed = with_decay.step(&gradients, &params);
+
+        let mut no_decay = Adam::new(None, None, None, None, 3);
+        let not_decayed = no_decay.step(&gradients, &params);
+
+        for i in 0..params.len() {
+            assert!(
+                decayed[i].abs() < not_decayed[i].abs(),
+                "param[{i}]: with_decay={}, no_decay={}",
+                decayed[i],
+                not_decayed[i]
+            );
+        }
+    }
+
+    #[test]
+    fn test_adamw_step_weight_decay_zero_matches_adam() {
+        // weight_decay = 0.0 must be numerically identical to standard Adam.
+        let gradients = vec![9.0, -8.0, 7.0, -6.0, 5.0, -4.0, 3.0, -2.0, 1.0];
+        let params = vec![0.0; 9];
+
+        let mut adamw = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), Some(0.0), 9);
+        let mut adam = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), None, 9);
+
+        assert_eq!(
+            adamw.step(&gradients, &params),
+            adam.step(&gradients, &params)
+        );
+    }
+
+    #[test]
+    fn test_adamw_step_decay_pulls_params_toward_zero() {
+        // Each updated parameter must be closer to zero than its predecessor.
+        let gradients = vec![1.0, -1.0, 2.0, -2.0];
+        let params = vec![0.1, -0.1, 0.2, -0.2];
+
+        let mut optimizer = Adam::new(Some(0.01), Some((0.9, 0.999)), Some(1e-8), Some(0.01), 4);
+        let updated = optimizer.step(&gradients, &params);
+
+        assert!(updated[0] < params[0]); // positive param, positive grad → decrease
+        assert!(updated[1] > params[1]); // negative param, negative grad → increase
+        assert!(updated[2] < params[2]);
+        assert!(updated[3] > params[3]);
+    }
+
+    // ── Step: shared edge cases ───────────────────────────────────────────────
+
+    #[test]
+    #[should_panic(expected = "gradients and params must have the same length")]
+    fn test_step_mismatched_lengths_panics() {
+        let mut optimizer = Adam::new(None, None, None, None, 3);
+        optimizer.step(&[1.0, 2.0, 3.0], &[0.0, 0.0]); // params too short
+    }
+
+    // ── Convergence (slow; marked #[ignore]) ─────────────────────────────────
+
     #[ignore]
     #[test]
     fn test_adam_step_iteratively_until_convergence_with_default_params() {
         const CONVERGENCE_THRESHOLD: f64 = 1e-5;
         let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

-        let mut optimizer = Adam::new(None, None, None, 6);
+        let mut optimizer = Adam::new(None, None, None, None, 6);

         let mut model_params = vec![0.0; 6];
-        let mut updated_params = optimizer.step(&gradients);
+        let mut updated_params = optimizer.step(&gradients, &model_params);

         while (updated_params
             .iter()
@@ -226,7 +354,7 @@ mod tests {
             > CONVERGENCE_THRESHOLD
         {
             model_params = updated_params;
-            updated_params = optimizer.step(&gradients);
+            updated_params = optimizer.step(&gradients, &model_params);
         }

         assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 6]);
@@ -250,10 +378,10 @@ mod tests {
         const CONVERGENCE_THRESHOLD: f64 = 1e-7;
         let gradients = vec![7.0, -8.0, 9.0, -10.0, 11.0, -12.0, 13.0];

-        let mut optimizer = Adam::new(Some(0.005), Some((0.8, 0.899)), Some(1e-5), 7);
+        let mut optimizer = Adam::new(Some(0.005), Some((0.8, 0.899)), Some(1e-5), None, 7);

         let mut model_params = vec![0.0; 7];
-        let mut updated_params = optimizer.step(&gradients);
+        let mut updated_params = optimizer.step(&gradients, &model_params);

         while (updated_params
             .iter()
@@ -267,7 +395,7 @@ mod tests {
             > CONVERGENCE_THRESHOLD
         {
             model_params = updated_params;
-            updated_params = optimizer.step(&gradients);
+            updated_params = optimizer.step(&gradients, &model_params);
         }

         assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 7]);
@@ -285,4 +413,33 @@
             ]
         );
     }
+
+    #[ignore]
+    #[test]
+    fn test_adamw_step_iteratively_until_convergence() {
+        const CONVERGENCE_THRESHOLD: f64 = 1e-5;
+        let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
+
+        let mut optimizer = Adam::new(None, None, None, Some(0.0), 6);
+
+        let mut params = vec![0.0; 6];
+        let mut updated = optimizer.step(&gradients, &params);
+
+        while (updated
+            .iter()
+            .zip(params.iter())
+            .map(|(x, y)| x - y)
+            .collect::<Vec<f64>>())
+        .iter()
+        .map(|&x| x.powi(2))
+        .sum::<f64>()
+        .sqrt()
+            > CONVERGENCE_THRESHOLD
+        {
+            params = updated;
+            updated = optimizer.step(&gradients, &params);
+        }
+
+        assert_ne!(updated, params);
+    }
 }
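
Because `step` now takes θ_{t-1} and returns θ_t instead of accumulating an internal parameter vector, the caller owns the parameter state, exactly as the convergence tests above do. A minimal training-loop sketch under the same in-scope assumption; `compute_gradients` is a hypothetical stand-in for a real backward pass:

fn main() {
    // Hypothetical helper: the gradient of f(θ) = Σ θ_i^2 / 2 is simply θ.
    fn compute_gradients(params: &[f64]) -> Vec<f64> {
        params.to_vec()
    }

    let mut optimizer = Adam::new(None, None, None, Some(0.01), 2);
    let mut params = vec![1.0, -1.0];

    for _ in 0..1_000 {
        let gradients = compute_gradients(&params);
        // Feed the returned θ_t back in as the next step's θ_{t-1}.
        params = optimizer.step(&gradients, &params);
    }

    println!("{params:?}"); // drifts toward the minimum at zero
}
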
