//! learning problems. It is memory-efficient and converges quickly: it sets and iteratively
//! updates a learning rate for each model parameter individually, based on the gradient history.
//!
//! Setting `weight_decay > 0.0` enables the AdamW variant (Loshchilov & Hutter, 2019), which
//! applies weight decay directly to the parameters rather than folding it into the gradients.
//! This keeps the decay rate constant and independent of the gradient history, which is the
//! flaw that AdamW corrects in naive L2 regularization inside Adam. With `weight_decay = 0.0`
//! (the default), the two algorithms are identical.
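//!
//! As a sketch of the difference in this module's notation: L2 regularization folded into Adam
//! rewrites the gradient as g_t ← g_t + λ * θ_{t-1}, so the penalty passes through the adaptive
//! denominator sqrt(v_hat_t) + ϵ and effectively shrinks for parameters with a large gradient
//! history, whereas the decoupled AdamW update below subtracts α * λ * θ_{t-1} unchanged.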
//!
//! ## Algorithm:
//!
//! Given:
//! - α is the learning rate
//! - (β_1, β_2) are the exponential decay rates for the moment estimates
//! - ϵ is a small constant that prevents division by zero
//! - λ is the weight decay coefficient (0.0 for standard Adam, > 0.0 for AdamW)
//! - g_t are the gradients at time step t
//! - m_t are the biased first moment estimates of the gradient at time step t
//! - v_t are the biased second raw moment estimates of the gradient at time step t
//! - θ_t are the model parameters at time step t
//! while θ_t not converged do
//!     m_t = β_1 * m_{t-1} + (1 − β_1) * g_t
//!     v_t = β_2 * v_{t-1} + (1 − β_2) * g_t^2
//!     m_hat_t = m_t / (1 − β_1^t)
//!     v_hat_t = v_t / (1 − β_2^t)
//!     θ_t = θ_{t-1} − α * (m_hat_t / (sqrt(v_hat_t) + ϵ) + λ * θ_{t-1})
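//!
//! As a worked example of the bias correction: at t = 1 with g_1 = 1 and the default
//! (β_1, β_2) = (0.9, 0.999), the raw moments are m_1 = 0.1 and v_1 = 0.001, but the corrected
//! estimates are m_hat_1 = 0.1 / (1 − 0.9) = 1 and v_hat_1 = 0.001 / (1 − 0.999) = 1, so the
//! first steps are not artificially shrunk by the zero-initialized moment vectors.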
//!
//! ## Resources:
//! - Adam: A Method for Stochastic Optimization (by Diederik P. Kingma and Jimmy Ba):
//!   - [https://arxiv.org/abs/1412.6980]
//! - Decoupled Weight Decay Regularization (by Ilya Loshchilov and Frank Hutter):
//!   - [https://arxiv.org/abs/1711.05101]
//! - PyTorch Adam optimizer:
//!   - [https://pytorch.org/docs/stable/generated/torch.optim.Adam.html]
//! - PyTorch AdamW optimizer:
//!   - [https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html]
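//!
//! ## Example:
//!
//! A minimal usage sketch of the `new`/`step` API defined below (the gradient and parameter
//! values are hypothetical):
//!
//! ```ignore
//! // Defaults: lr = 1e-3, betas = (0.9, 0.999), epsilon = 1e-8, weight_decay = 0.0.
//! let mut adam = Adam::new(None, None, None, None, 2);
//! let params = vec![0.5, -0.5];
//! let gradients = vec![0.1, -0.2];
//! let updated = adam.step(&gradients, &params); // one standard Adam update
//!
//! // Any positive weight_decay switches the same struct to the AdamW update.
//! let mut adamw = Adam::new(None, None, None, Some(0.01), 2);
//! let decayed = adamw.step(&gradients, &params);
//! ```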
//!
pub struct Adam {
    learning_rate: f64,  // alpha: initial step size for iterative optimization
    betas: (f64, f64),   // betas: exponential decay rates for moment estimates
    epsilon: f64,        // epsilon: prevents division by zero
    weight_decay: f64,   // lambda: decoupled weight decay coefficient (0.0 = standard Adam)
    m: Vec<f64>,         // m: biased first moment estimate of the gradient vector
    v: Vec<f64>,         // v: biased second raw moment estimate of the gradient vector
    t: usize,            // t: time step
}

impl Adam {
    pub fn new(
        learning_rate: Option<f64>,
        betas: Option<(f64, f64)>,
        epsilon: Option<f64>,
        weight_decay: Option<f64>,
        params_len: usize,
    ) -> Self {
        Adam {
            learning_rate: learning_rate.unwrap_or(1e-3), // typical good default lr
            betas: betas.unwrap_or((0.9, 0.999)),         // typical good default decay rates
            epsilon: epsilon.unwrap_or(1e-8),             // typical good default epsilon
            weight_decay: weight_decay.unwrap_or(0.0),    // 0.0 = standard Adam, > 0.0 = AdamW
            m: vec![0.0; params_len], // first moment vector elements all initialized to zero
            v: vec![0.0; params_len], // second moment vector elements all initialized to zero
            t: 0,                     // time step initialized to zero
        }
    }

    /// Computes one update step.
    ///
    /// `params` holds the current parameter values θ_{t-1}. When `weight_decay`
    /// is `0.0` the update is standard Adam; any positive value applies the AdamW
    /// decoupled decay term `λ * θ_{t-1}` directly to the parameters, independent
    /// of the adaptive scaling.
    ///
    /// # Panics
    ///
    /// Panics if `gradients` and `params` have different lengths.
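    ///
    /// # Examples
    ///
    /// A minimal sketch of a single call (hypothetical gradient and parameter values):
    ///
    /// ```ignore
    /// let mut opt = Adam::new(None, None, None, Some(0.01), 2); // AdamW with λ = 0.01
    /// let updated = opt.step(&[0.1, -0.2], &[0.5, -0.5]);
    /// assert_eq!(updated.len(), 2);
    /// ```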
    pub fn step(&mut self, gradients: &[f64], params: &[f64]) -> Vec<f64> {
        assert_eq!(
            gradients.len(),
            params.len(),
            "gradients and params must have the same length"
        );

        let mut updated_params = vec![0.0; params.len()];
        self.t += 1;

        for i in 0..gradients.len() {
            // update the biased first and second raw moment estimates from the current gradient
            self.m[i] = self.betas.0 * self.m[i] + (1.0 - self.betas.0) * gradients[i];
            self.v[i] = self.betas.1 * self.v[i] + (1.0 - self.betas.1) * gradients[i].powi(2);

            // compute bias-corrected moment estimates
            let m_hat = self.m[i] / (1.0 - self.betas.0.powi(self.t as i32));
            let v_hat = self.v[i] / (1.0 - self.betas.1.powi(self.t as i32));

            // Adaptive gradient step: keeps the original (lr * m_hat) / denom operator
            // order so floating-point results are identical to standard Adam when
            // weight_decay = 0.0. The decoupled decay term is added separately so it
            // does not interact with the adaptive scaling.
            updated_params[i] = params[i]
                - self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon)
                - self.learning_rate * self.weight_decay * params[i];
        }
        updated_params // return updated model parameters
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // ── Initialisation ───────────────────────────────────────────────────────

    #[test]
    fn test_adam_init_default_values() {
        let optimizer = Adam::new(None, None, None, None, 1);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.weight_decay, 0.0);
        assert_eq!(optimizer.m, vec![0.0; 1]);
        assert_eq!(optimizer.v, vec![0.0; 1]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_custom_lr_value() {
        let optimizer = Adam::new(Some(0.9), None, None, None, 2);

        assert_eq!(optimizer.learning_rate, 0.9);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.weight_decay, 0.0);
        assert_eq!(optimizer.m, vec![0.0; 2]);
        assert_eq!(optimizer.v, vec![0.0; 2]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_custom_betas_value() {
        let optimizer = Adam::new(None, Some((0.8, 0.899)), None, None, 3);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.8, 0.899));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.weight_decay, 0.0);
        assert_eq!(optimizer.m, vec![0.0; 3]);
        assert_eq!(optimizer.v, vec![0.0; 3]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_custom_epsilon_value() {
        let optimizer = Adam::new(None, None, Some(1e-10), None, 4);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-10);
        assert_eq!(optimizer.weight_decay, 0.0);
        assert_eq!(optimizer.m, vec![0.0; 4]);
        assert_eq!(optimizer.v, vec![0.0; 4]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_custom_weight_decay_value() {
        let optimizer = Adam::new(None, None, None, Some(0.1), 3);

        assert_eq!(optimizer.learning_rate, 0.001);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.epsilon, 1e-8);
        assert_eq!(optimizer.weight_decay, 0.1);
        assert_eq!(optimizer.m, vec![0.0; 3]);
        assert_eq!(optimizer.v, vec![0.0; 3]);
        assert_eq!(optimizer.t, 0);
    }

    #[test]
    fn test_adam_init_all_custom_values() {
        let optimizer = Adam::new(Some(1.0), Some((0.001, 0.099)), Some(1e-1), Some(0.05), 5);

        assert_eq!(optimizer.learning_rate, 1.0);
        assert_eq!(optimizer.betas, (0.001, 0.099));
        assert_eq!(optimizer.epsilon, 1e-1);
        assert_eq!(optimizer.weight_decay, 0.05);
        assert_eq!(optimizer.m, vec![0.0; 5]);
        assert_eq!(optimizer.v, vec![0.0; 5]);
        assert_eq!(optimizer.t, 0);
    }

    // ── Step: standard Adam (weight_decay = 0.0) ──────────────────────────────

    #[test]
    fn test_adam_step_default_params() {
        let gradients = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];
        let params = vec![0.0; 8];

        let mut optimizer = Adam::new(None, None, None, None, 8);
        let updated_params = optimizer.step(&gradients, &params);

        // assert_eq!(updated_params, ...): exact expected values elided in this excerpt
    }

    #[test]
    fn test_adam_step_custom_params() {
        let gradients = vec![9.0, -8.0, 7.0, -6.0, 5.0, -4.0, 3.0, -2.0, 1.0];
        let params = vec![0.0; 9];

        let mut optimizer = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), None, 9);
        let updated_params = optimizer.step(&gradients, &params);

        // assert_eq!(updated_params, ...): exact expected values elided in this excerpt
    }

    #[test]
    fn test_adam_step_empty_gradients_array() {
        let gradients: Vec<f64> = vec![];
        let params: Vec<f64> = vec![];

        let mut optimizer = Adam::new(None, None, None, None, 0);
        let updated_params = optimizer.step(&gradients, &params);

        assert_eq!(updated_params, vec![]);
    }

    // ── Step: AdamW (weight_decay > 0.0) ─────────────────────────────────────

    #[test]
    fn test_adamw_step_nonzero_params_applies_decay() {
        // When params are non-zero and weight_decay > 0.0, the decay term must pull
        // every parameter strictly closer to zero than the plain adaptive step would.
        // Comparing against a no-decay run avoids replicating the internal floating-
        // point computation path and tests the property that actually matters.
        let gradients = vec![1.0, -2.0, 3.0];
        let params = vec![0.5, -0.5, 1.0];

        let mut with_decay = Adam::new(None, None, None, Some(0.01), 3);
        let decayed = with_decay.step(&gradients, &params);

        let mut no_decay = Adam::new(None, None, None, None, 3);
        let not_decayed = no_decay.step(&gradients, &params);

        for i in 0..params.len() {
            assert!(
                decayed[i].abs() < not_decayed[i].abs(),
                "param[{i}]: with_decay={}, no_decay={}",
                decayed[i],
                not_decayed[i]
            );
        }
    }

    #[test]
    fn test_adamw_step_weight_decay_zero_matches_adam() {
        // weight_decay = 0.0 must be numerically identical to standard Adam.
        let gradients = vec![9.0, -8.0, 7.0, -6.0, 5.0, -4.0, 3.0, -2.0, 1.0];
        let params = vec![0.0; 9];

        let mut adamw = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), Some(0.0), 9);
        let mut adam = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), None, 9);

        assert_eq!(
            adamw.step(&gradients, &params),
            adam.step(&gradients, &params)
        );
    }

    #[test]
    fn test_adamw_step_decay_pulls_params_toward_zero() {
        // The combined gradient and decay step must move each parameter strictly
        // toward zero relative to its previous value.
        let gradients = vec![1.0, -1.0, 2.0, -2.0];
        let params = vec![0.1, -0.1, 0.2, -0.2];

        let mut optimizer = Adam::new(Some(0.01), Some((0.9, 0.999)), Some(1e-8), Some(0.01), 4);
        let updated = optimizer.step(&gradients, &params);

        assert!(updated[0] < params[0]); // positive param, positive grad → decrease
        assert!(updated[1] > params[1]); // negative param, negative grad → increase
        assert!(updated[2] < params[2]);
        assert!(updated[3] > params[3]);
    }

    // ── Step: shared edge cases ───────────────────────────────────────────────

    #[test]
    #[should_panic(expected = "gradients and params must have the same length")]
    fn test_step_mismatched_lengths_panics() {
        let mut optimizer = Adam::new(None, None, None, None, 3);
        optimizer.step(&[1.0, 2.0, 3.0], &[0.0, 0.0]); // params too short
    }

    // ── Convergence (slow; marked #[ignore]) ─────────────────────────────────

    #[ignore]
    #[test]
    fn test_adam_step_iteratively_until_convergence_with_default_params() {
        const CONVERGENCE_THRESHOLD: f64 = 1e-5;
        let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let mut optimizer = Adam::new(None, None, None, None, 6);

        let mut model_params = vec![0.0; 6];
        let mut updated_params = optimizer.step(&gradients, &model_params);

        while (updated_params
            .iter()
            .zip(model_params.iter())
            .map(|(x, y)| x - y)
            .collect::<Vec<f64>>())
        .iter()
        .map(|&x| x.powi(2))
        .sum::<f64>()
        .sqrt()
            > CONVERGENCE_THRESHOLD
        {
            model_params = updated_params;
            updated_params = optimizer.step(&gradients, &model_params);
        }

        assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 6]);
        // final assert_eq! with exact expected values elided in this excerpt
    }

    #[ignore]
    #[test]
    fn test_adam_step_iteratively_until_convergence_with_custom_params() {
        const CONVERGENCE_THRESHOLD: f64 = 1e-7;
        let gradients = vec![7.0, -8.0, 9.0, -10.0, 11.0, -12.0, 13.0];

        let mut optimizer = Adam::new(Some(0.005), Some((0.8, 0.899)), Some(1e-5), None, 7);

        let mut model_params = vec![0.0; 7];
        let mut updated_params = optimizer.step(&gradients, &model_params);

        while (updated_params
            .iter()
            .zip(model_params.iter())
            .map(|(x, y)| x - y)
            .collect::<Vec<f64>>())
        .iter()
        .map(|&x| x.powi(2))
        .sum::<f64>()
        .sqrt()
            > CONVERGENCE_THRESHOLD
        {
            model_params = updated_params;
            updated_params = optimizer.step(&gradients, &model_params);
        }

        assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 7]);
        // final assert_eq! with exact expected values elided in this excerpt
    }

    #[ignore]
    #[test]
    fn test_adamw_step_iteratively_until_convergence() {
        const CONVERGENCE_THRESHOLD: f64 = 1e-5;
        let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        // weight_decay = Some(0.0) exercises the AdamW argument path while keeping
        // the dynamics identical to standard Adam, so the loop below terminates.
        let mut optimizer = Adam::new(None, None, None, Some(0.0), 6);

        let mut params = vec![0.0; 6];
        let mut updated = optimizer.step(&gradients, &params);

        while (updated
            .iter()
            .zip(params.iter())
            .map(|(x, y)| x - y)
            .collect::<Vec<f64>>())
        .iter()
        .map(|&x| x.powi(2))
        .sum::<f64>()
        .sqrt()
            > CONVERGENCE_THRESHOLD
        {
            params = updated;
            updated = optimizer.step(&gradients, &params);
        }

        // the final step still moves the parameters, just by less than the threshold
        assert_ne!(updated, params);
    }
}