@@ -12,13 +12,29 @@ use rust_lstm::{
1212 ScheduledLSTMTrainer , ScheduledOptimizer , StepLR , TrainingConfig , WarmupScheduler ,
1313} ;
1414
15+ pub const DEMO_TRAIN_SEQUENCES : usize = 12 ;
16+ pub const DEMO_VAL_SEQUENCES : usize = 4 ;
17+ pub const DEMO_SEQUENCE_LENGTH : usize = 6 ;
18+ pub const DEMO_HIDDEN_SIZE : usize = 4 ;
19+ pub const DEMO_ADVANCED_HIDDEN_SIZE : usize = 6 ;
20+ pub const DEMO_POLYNOMIAL_EPOCHS : usize = 5 ;
21+ pub const DEMO_CYCLICAL_EPOCHS : usize = 5 ;
22+ pub const DEMO_WARMUP_EPOCHS : usize = 5 ;
23+ pub const DEMO_ADVANCED_EPOCHS : usize = 5 ;
24+ pub const DEMO_POLYNOMIAL_ITERS : usize = DEMO_POLYNOMIAL_EPOCHS ;
25+ pub const DEMO_CYCLICAL_STEP_SIZE : usize = 2 ;
26+ pub const DEMO_WARMUP_EPOCH_COUNT : usize = 2 ;
27+ pub const DEMO_BASE_STEP_SIZE : usize = 2 ;
28+ pub const DEMO_VISUALIZATION_STEP_SIZE : usize = 2 ;
29+ pub const DEMO_VISUALIZATION_STEPS : usize = 20 ;
30+
1531fn main ( ) {
1632 println ! ( "🚀 Advanced Learning Rate Scheduling for Rust-LSTM" ) ;
1733 println ! ( "===================================================\n " ) ;
1834
1935 // Generate sample training data
20- let train_data = generate_sine_wave_data ( 50 , 0.0 ) ;
21- let val_data = generate_sine_wave_data ( 10 , 1000.0 ) ;
36+ let train_data = generate_sine_wave_data ( DEMO_TRAIN_SEQUENCES , 0.0 ) ;
37+ let val_data = generate_sine_wave_data ( DEMO_VAL_SEQUENCES , 1000.0 ) ;
2238
2339 // 1. Polynomial Decay Example
2440 polynomial_decay_example ( & train_data, & val_data) ;
@@ -43,24 +59,18 @@ fn polynomial_decay_example(
4359 println ! ( "1️⃣ Polynomial Decay Example" ) ;
4460 println ! ( " Smoothly decays LR using polynomial function\n " ) ;
4561
46- let network = LSTMNetwork :: new ( 1 , 8 , 1 ) ;
62+ let network = LSTMNetwork :: new ( 1 , DEMO_HIDDEN_SIZE , 1 ) ;
4763
4864 let loss_function = MSELoss ;
4965 let scheduled_optimizer = ScheduledOptimizer :: polynomial (
5066 Adam :: new ( 0.01 ) ,
51- 0.01 , // base_lr
52- 25 , // total_iters
53- 2.0 , // power
54- 0.001 , // end_lr
67+ 0.01 , // base_lr
68+ DEMO_POLYNOMIAL_ITERS , // total_iters
69+ 2.0 , // power
70+ 0.001 , // end_lr
5571 ) ;
5672
57- let config = TrainingConfig {
58- epochs : 30 ,
59- print_every : 5 ,
60- clip_gradient : Some ( 1.0 ) ,
61- log_lr_changes : true ,
62- early_stopping : None ,
63- } ;
73+ let config = polynomial_decay_training_config ( ) ;
6474
6575 let mut trainer =
6676 ScheduledLSTMTrainer :: new ( network, loss_function, scheduled_optimizer) . with_config ( config) ;
@@ -80,23 +90,17 @@ fn cyclical_lr_examples(
8090
8191 // 2a. Triangular Cyclical LR
8292 println ! ( "2a. Triangular Cyclical LR" ) ;
83- let network = LSTMNetwork :: new ( 1 , 8 , 1 ) ;
93+ let network = LSTMNetwork :: new ( 1 , DEMO_HIDDEN_SIZE , 1 ) ;
8494
8595 let loss_function = MSELoss ;
8696 let scheduled_optimizer = ScheduledOptimizer :: cyclical (
8797 Adam :: new ( 0.001 ) ,
88- 0.001 , // base_lr
89- 0.01 , // max_lr
90- 8 , // step_size
98+ 0.001 , // base_lr
99+ 0.01 , // max_lr
100+ DEMO_CYCLICAL_STEP_SIZE , // step_size
91101 ) ;
92102
93- let config = TrainingConfig {
94- epochs : 25 ,
95- print_every : 5 ,
96- clip_gradient : Some ( 1.0 ) ,
97- log_lr_changes : false , // Too frequent for cyclical
98- early_stopping : None ,
99- } ;
103+ let config = cyclical_lr_training_config ( ) ;
100104
101105 let mut trainer =
102106 ScheduledLSTMTrainer :: new ( network, loss_function, scheduled_optimizer) . with_config ( config) ;
@@ -106,23 +110,17 @@ fn cyclical_lr_examples(
106110
107111 // 2b. Triangular2 Cyclical LR (halving amplitude each cycle)
108112 println ! ( "2b. Triangular2 Cyclical LR (halving amplitude each cycle)" ) ;
109- let network = LSTMNetwork :: new ( 1 , 8 , 1 ) ;
113+ let network = LSTMNetwork :: new ( 1 , DEMO_HIDDEN_SIZE , 1 ) ;
110114
111115 let loss_function = MSELoss ;
112116 let scheduled_optimizer = ScheduledOptimizer :: cyclical_triangular2 (
113117 Adam :: new ( 0.001 ) ,
114- 0.001 , // base_lr
115- 0.01 , // max_lr
116- 8 , // step_size
118+ 0.001 , // base_lr
119+ 0.01 , // max_lr
120+ DEMO_CYCLICAL_STEP_SIZE , // step_size
117121 ) ;
118122
119- let config2 = TrainingConfig {
120- epochs : 25 ,
121- print_every : 5 ,
122- clip_gradient : Some ( 1.0 ) ,
123- log_lr_changes : false ,
124- early_stopping : None ,
125- } ;
123+ let config2 = cyclical_lr_training_config ( ) ;
126124
127125 let mut trainer =
128126 ScheduledLSTMTrainer :: new ( network, loss_function, scheduled_optimizer) . with_config ( config2) ;
@@ -132,24 +130,18 @@ fn cyclical_lr_examples(
132130
133131 // 2c. ExpRange Cyclical LR (exponential scaling)
134132 println ! ( "2c. ExpRange Cyclical LR (exponential scaling)" ) ;
135- let network = LSTMNetwork :: new ( 1 , 8 , 1 ) ;
133+ let network = LSTMNetwork :: new ( 1 , DEMO_HIDDEN_SIZE , 1 ) ;
136134
137135 let loss_function = MSELoss ;
138136 let scheduled_optimizer = ScheduledOptimizer :: cyclical_exp_range (
139137 Adam :: new ( 0.001 ) ,
140- 0.001 , // base_lr
141- 0.01 , // max_lr
142- 8 , // step_size
143- 0.95 , // gamma
138+ 0.001 , // base_lr
139+ 0.01 , // max_lr
140+ DEMO_CYCLICAL_STEP_SIZE , // step_size
141+ 0.95 , // gamma
144142 ) ;
145143
146- let config3 = TrainingConfig {
147- epochs : 25 ,
148- print_every : 5 ,
149- clip_gradient : Some ( 1.0 ) ,
150- log_lr_changes : false ,
151- early_stopping : None ,
152- } ;
144+ let config3 = cyclical_lr_training_config ( ) ;
153145
154146 let mut trainer =
155147 ScheduledLSTMTrainer :: new ( network, loss_function, scheduled_optimizer) . with_config ( config3) ;
@@ -167,26 +159,15 @@ fn warmup_scheduler_example(
167159 println ! ( "3️⃣ Warmup Scheduler Example" ) ;
168160 println ! ( " Gradually increases LR during warmup, then applies base scheduler\n " ) ;
169161
170- let network = LSTMNetwork :: new ( 1 , 8 , 1 ) ;
162+ let network = LSTMNetwork :: new ( 1 , DEMO_HIDDEN_SIZE , 1 ) ;
171163
172- // Create warmup scheduler with step decay after warmup
173- let base_scheduler = StepLR :: new ( 10 , 0.5 ) ; // Reduce by half every 10 epochs
174- let warmup_scheduler = WarmupScheduler :: new (
175- 5 , // warmup_epochs
176- base_scheduler, // base_scheduler
177- 0.001 , // warmup_start_lr
178- ) ;
164+ let base_scheduler = StepLR :: new ( DEMO_BASE_STEP_SIZE , 0.5 ) ;
165+ let warmup_scheduler = WarmupScheduler :: new ( DEMO_WARMUP_EPOCH_COUNT , base_scheduler, 0.001 ) ;
179166
180167 let loss_function = MSELoss ;
181168 let scheduled_optimizer = ScheduledOptimizer :: new ( Adam :: new ( 0.01 ) , warmup_scheduler, 0.01 ) ;
182169
183- let config = TrainingConfig {
184- epochs : 30 ,
185- print_every : 3 ,
186- clip_gradient : Some ( 1.0 ) ,
187- log_lr_changes : true ,
188- early_stopping : None ,
189- } ;
170+ let config = warmup_scheduler_training_config ( ) ;
190171
191172 let mut trainer =
192173 ScheduledLSTMTrainer :: new ( network, loss_function, scheduled_optimizer) . with_config ( config) ;
@@ -202,21 +183,27 @@ fn schedule_visualization() {
202183 println ! ( " ASCII visualization of different schedulers\n " ) ;
203184
204185 // Visualize StepLR
205- println ! ( "StepLR (step_size=10 , gamma=0.5):" ) ;
206- let step_scheduler = StepLR :: new ( 10 , 0.5 ) ;
207- LRScheduleVisualizer :: print_schedule ( step_scheduler, 0.01 , 50 , 60 , 10 ) ;
186+ println ! ( "StepLR (step_size=2 , gamma=0.5):" ) ;
187+ let step_scheduler = StepLR :: new ( DEMO_VISUALIZATION_STEP_SIZE , 0.5 ) ;
188+ LRScheduleVisualizer :: print_schedule ( step_scheduler, 0.01 , DEMO_VISUALIZATION_STEPS , 40 , 5 ) ;
208189 println ! ( ) ;
209190
210191 // Visualize PolynomialLR
211192 println ! ( "PolynomialLR (power=2.0, end_lr=0.001):" ) ;
212- let poly_scheduler = PolynomialLR :: new ( 50 , 2.0 , 0.001 ) ;
213- LRScheduleVisualizer :: print_schedule ( poly_scheduler, 0.01 , 50 , 60 , 10 ) ;
193+ let poly_scheduler = PolynomialLR :: new ( DEMO_VISUALIZATION_STEPS , 2.0 , 0.001 ) ;
194+ LRScheduleVisualizer :: print_schedule ( poly_scheduler, 0.01 , DEMO_VISUALIZATION_STEPS , 40 , 5 ) ;
214195 println ! ( ) ;
215196
216197 // Visualize CyclicalLR
217- println ! ( "CyclicalLR Triangular (base_lr=0.001, max_lr=0.01, step_size=8):" ) ;
218- let cyclical_scheduler = CyclicalLR :: new ( 0.001 , 0.01 , 8 ) ;
219- LRScheduleVisualizer :: print_schedule ( cyclical_scheduler, 0.001 , 50 , 60 , 10 ) ;
198+ println ! ( "CyclicalLR Triangular (base_lr=0.001, max_lr=0.01, step_size=2):" ) ;
199+ let cyclical_scheduler = CyclicalLR :: new ( 0.001 , 0.01 , DEMO_CYCLICAL_STEP_SIZE ) ;
200+ LRScheduleVisualizer :: print_schedule (
201+ cyclical_scheduler,
202+ 0.001 ,
203+ DEMO_VISUALIZATION_STEPS ,
204+ 40 ,
205+ 5 ,
206+ ) ;
220207 println ! ( ) ;
221208
222209 println ! ( "----------------------------------------\n " ) ;
@@ -230,25 +217,20 @@ fn advanced_training_example(
230217 println ! ( " Warmup + Cyclical LR + Dropout + Gradient Clipping\n " ) ;
231218
232219 // Create network with dropout
233- let network = LSTMNetwork :: new ( 1 , 16 , 1 )
220+ let network = LSTMNetwork :: new ( 1 , DEMO_ADVANCED_HIDDEN_SIZE , 1 )
234221 . with_input_dropout ( 0.1 , true ) // Variational dropout
235222 . with_recurrent_dropout ( 0.2 , true ) // Variational recurrent dropout
236223 . with_output_dropout ( 0.1 ) ; // Standard output dropout
237224
238225 // Create warmup scheduler with cyclical base scheduler
239- let base_scheduler = CyclicalLR :: new ( 0.001 , 0.01 , 10 ) . with_mode ( CyclicalMode :: Triangular2 ) ;
240- let warmup_scheduler = WarmupScheduler :: new ( 5 , base_scheduler, 0.0001 ) ;
226+ let base_scheduler =
227+ CyclicalLR :: new ( 0.001 , 0.01 , DEMO_CYCLICAL_STEP_SIZE ) . with_mode ( CyclicalMode :: Triangular2 ) ;
228+ let warmup_scheduler = WarmupScheduler :: new ( DEMO_WARMUP_EPOCH_COUNT , base_scheduler, 0.0001 ) ;
241229
242230 let loss_function = MSELoss ;
243231 let scheduled_optimizer = ScheduledOptimizer :: new ( Adam :: new ( 0.01 ) , warmup_scheduler, 0.01 ) ;
244232
245- let config = TrainingConfig {
246- epochs : 40 ,
247- print_every : 5 ,
248- clip_gradient : Some ( 1.0 ) , // Gradient clipping
249- log_lr_changes : false , // Too frequent for cyclical
250- early_stopping : None ,
251- } ;
233+ let config = advanced_training_config ( ) ;
252234
253235 let mut trainer =
254236 ScheduledLSTMTrainer :: new ( network, loss_function, scheduled_optimizer) . with_config ( config) ;
@@ -272,18 +254,17 @@ fn advanced_training_example(
272254 println ! ( "\n ✅ Advanced training complete!" ) ;
273255}
274256
275- fn generate_sine_wave_data (
257+ pub fn generate_sine_wave_data (
276258 num_sequences : usize ,
277259 offset : f64 ,
278260) -> Vec < ( Vec < Array2 < f64 > > , Vec < Array2 < f64 > > ) > {
279261 let mut data = Vec :: new ( ) ;
280262
281263 for i in 0 ..num_sequences {
282- let sequence_length = 8 ;
283264 let mut inputs = Vec :: new ( ) ;
284265 let mut targets = Vec :: new ( ) ;
285266
286- for t in 0 ..sequence_length {
267+ for t in 0 ..DEMO_SEQUENCE_LENGTH {
287268 let x = ( offset + i as f64 * 0.1 + t as f64 * 0.2 ) . sin ( ) ;
288269 let y = ( offset + i as f64 * 0.1 + ( t + 1 ) as f64 * 0.2 ) . sin ( ) ;
289270
@@ -297,19 +278,59 @@ fn generate_sine_wave_data(
297278 data
298279}
299280
281+ pub fn polynomial_decay_training_config ( ) -> TrainingConfig {
282+ TrainingConfig {
283+ epochs : DEMO_POLYNOMIAL_EPOCHS ,
284+ print_every : 1 ,
285+ clip_gradient : Some ( 1.0 ) ,
286+ log_lr_changes : true ,
287+ early_stopping : None ,
288+ }
289+ }
290+
291+ pub fn cyclical_lr_training_config ( ) -> TrainingConfig {
292+ TrainingConfig {
293+ epochs : DEMO_CYCLICAL_EPOCHS ,
294+ print_every : 1 ,
295+ clip_gradient : Some ( 1.0 ) ,
296+ log_lr_changes : false ,
297+ early_stopping : None ,
298+ }
299+ }
300+
301+ pub fn warmup_scheduler_training_config ( ) -> TrainingConfig {
302+ TrainingConfig {
303+ epochs : DEMO_WARMUP_EPOCHS ,
304+ print_every : 1 ,
305+ clip_gradient : Some ( 1.0 ) ,
306+ log_lr_changes : true ,
307+ early_stopping : None ,
308+ }
309+ }
310+
311+ pub fn advanced_training_config ( ) -> TrainingConfig {
312+ TrainingConfig {
313+ epochs : DEMO_ADVANCED_EPOCHS ,
314+ print_every : 1 ,
315+ clip_gradient : Some ( 1.0 ) ,
316+ log_lr_changes : false ,
317+ early_stopping : None ,
318+ }
319+ }
320+
300321#[ cfg( test) ]
301322mod tests {
302323 use super :: * ;
303- use rust_lstm:: SGD ;
304324
305325 #[ test]
306326 fn test_advanced_schedulers ( ) {
307327 // Test polynomial scheduler
308328 let poly_scheduler = PolynomialLR :: new ( 100 , 2.0 , 0.01 ) ;
309- let schedule = LRScheduleVisualizer :: generate_schedule ( poly_scheduler, 0.1 , 100 ) ;
310- assert_eq ! ( schedule. len( ) , 100 ) ;
329+ let schedule = LRScheduleVisualizer :: generate_schedule ( poly_scheduler, 0.1 , 101 ) ;
330+ assert_eq ! ( schedule. len( ) , 101 ) ;
311331 assert_eq ! ( schedule[ 0 ] . 1 , 0.1 ) ;
312- assert ! ( ( schedule[ 99 ] . 1 - 0.01 ) . abs( ) < 1e-10 ) ;
332+ assert ! ( schedule[ 99 ] . 1 < schedule[ 0 ] . 1 ) ;
333+ assert ! ( ( schedule[ 100 ] . 1 - 0.01 ) . abs( ) < 1e-10 ) ;
313334
314335 // Test cyclical scheduler
315336 let cyclical_scheduler = CyclicalLR :: new ( 0.01 , 0.1 , 10 ) ;
0 commit comments