diff --git a/examples/advanced_lr_scheduling.rs b/examples/advanced_lr_scheduling.rs index 6b58456..d251f51 100644 --- a/examples/advanced_lr_scheduling.rs +++ b/examples/advanced_lr_scheduling.rs @@ -12,13 +12,29 @@ use rust_lstm::{ ScheduledLSTMTrainer, ScheduledOptimizer, StepLR, TrainingConfig, WarmupScheduler, }; +pub const DEMO_TRAIN_SEQUENCES: usize = 12; +pub const DEMO_VAL_SEQUENCES: usize = 4; +pub const DEMO_SEQUENCE_LENGTH: usize = 6; +pub const DEMO_HIDDEN_SIZE: usize = 4; +pub const DEMO_ADVANCED_HIDDEN_SIZE: usize = 6; +pub const DEMO_POLYNOMIAL_EPOCHS: usize = 5; +pub const DEMO_CYCLICAL_EPOCHS: usize = 5; +pub const DEMO_WARMUP_EPOCHS: usize = 5; +pub const DEMO_ADVANCED_EPOCHS: usize = 5; +pub const DEMO_POLYNOMIAL_ITERS: usize = DEMO_POLYNOMIAL_EPOCHS; +pub const DEMO_CYCLICAL_STEP_SIZE: usize = 2; +pub const DEMO_WARMUP_EPOCH_COUNT: usize = 2; +pub const DEMO_BASE_STEP_SIZE: usize = 2; +pub const DEMO_VISUALIZATION_STEP_SIZE: usize = 2; +pub const DEMO_VISUALIZATION_STEPS: usize = 20; + fn main() { println!("🚀 Advanced Learning Rate Scheduling for Rust-LSTM"); println!("===================================================\n"); // Generate sample training data - let train_data = generate_sine_wave_data(50, 0.0); - let val_data = generate_sine_wave_data(10, 1000.0); + let train_data = generate_sine_wave_data(DEMO_TRAIN_SEQUENCES, 0.0); + let val_data = generate_sine_wave_data(DEMO_VAL_SEQUENCES, 1000.0); // 1. Polynomial Decay Example polynomial_decay_example(&train_data, &val_data); @@ -43,24 +59,18 @@ fn polynomial_decay_example( println!("1️⃣ Polynomial Decay Example"); println!(" Smoothly decays LR using polynomial function\n"); - let network = LSTMNetwork::new(1, 8, 1); + let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 1); let loss_function = MSELoss; let scheduled_optimizer = ScheduledOptimizer::polynomial( Adam::new(0.01), - 0.01, // base_lr - 25, // total_iters - 2.0, // power - 0.001, // end_lr + 0.01, // base_lr + DEMO_POLYNOMIAL_ITERS, // total_iters + 2.0, // power + 0.001, // end_lr ); - let config = TrainingConfig { - epochs: 30, - print_every: 5, - clip_gradient: Some(1.0), - log_lr_changes: true, - early_stopping: None, - }; + let config = polynomial_decay_training_config(); let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config); @@ -80,23 +90,17 @@ fn cyclical_lr_examples( // 2a. Triangular Cyclical LR println!("2a. Triangular Cyclical LR"); - let network = LSTMNetwork::new(1, 8, 1); + let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 1); let loss_function = MSELoss; let scheduled_optimizer = ScheduledOptimizer::cyclical( Adam::new(0.001), - 0.001, // base_lr - 0.01, // max_lr - 8, // step_size + 0.001, // base_lr + 0.01, // max_lr + DEMO_CYCLICAL_STEP_SIZE, // step_size ); - let config = TrainingConfig { - epochs: 25, - print_every: 5, - clip_gradient: Some(1.0), - log_lr_changes: false, // Too frequent for cyclical - early_stopping: None, - }; + let config = cyclical_lr_training_config(); let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config); @@ -106,23 +110,17 @@ fn cyclical_lr_examples( // 2b. Triangular2 Cyclical LR (halving amplitude each cycle) println!("2b. Triangular2 Cyclical LR (halving amplitude each cycle)"); - let network = LSTMNetwork::new(1, 8, 1); + let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 1); let loss_function = MSELoss; let scheduled_optimizer = ScheduledOptimizer::cyclical_triangular2( Adam::new(0.001), - 0.001, // base_lr - 0.01, // max_lr - 8, // step_size + 0.001, // base_lr + 0.01, // max_lr + DEMO_CYCLICAL_STEP_SIZE, // step_size ); - let config2 = TrainingConfig { - epochs: 25, - print_every: 5, - clip_gradient: Some(1.0), - log_lr_changes: false, - early_stopping: None, - }; + let config2 = cyclical_lr_training_config(); let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config2); @@ -132,24 +130,18 @@ fn cyclical_lr_examples( // 2c. ExpRange Cyclical LR (exponential scaling) println!("2c. ExpRange Cyclical LR (exponential scaling)"); - let network = LSTMNetwork::new(1, 8, 1); + let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 1); let loss_function = MSELoss; let scheduled_optimizer = ScheduledOptimizer::cyclical_exp_range( Adam::new(0.001), - 0.001, // base_lr - 0.01, // max_lr - 8, // step_size - 0.95, // gamma + 0.001, // base_lr + 0.01, // max_lr + DEMO_CYCLICAL_STEP_SIZE, // step_size + 0.95, // gamma ); - let config3 = TrainingConfig { - epochs: 25, - print_every: 5, - clip_gradient: Some(1.0), - log_lr_changes: false, - early_stopping: None, - }; + let config3 = cyclical_lr_training_config(); let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config3); @@ -167,26 +159,15 @@ fn warmup_scheduler_example( println!("3️⃣ Warmup Scheduler Example"); println!(" Gradually increases LR during warmup, then applies base scheduler\n"); - let network = LSTMNetwork::new(1, 8, 1); + let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 1); - // Create warmup scheduler with step decay after warmup - let base_scheduler = StepLR::new(10, 0.5); // Reduce by half every 10 epochs - let warmup_scheduler = WarmupScheduler::new( - 5, // warmup_epochs - base_scheduler, // base_scheduler - 0.001, // warmup_start_lr - ); + let base_scheduler = StepLR::new(DEMO_BASE_STEP_SIZE, 0.5); + let warmup_scheduler = WarmupScheduler::new(DEMO_WARMUP_EPOCH_COUNT, base_scheduler, 0.001); let loss_function = MSELoss; let scheduled_optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01); - let config = TrainingConfig { - epochs: 30, - print_every: 3, - clip_gradient: Some(1.0), - log_lr_changes: true, - early_stopping: None, - }; + let config = warmup_scheduler_training_config(); let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config); @@ -202,21 +183,27 @@ fn schedule_visualization() { println!(" ASCII visualization of different schedulers\n"); // Visualize StepLR - println!("StepLR (step_size=10, gamma=0.5):"); - let step_scheduler = StepLR::new(10, 0.5); - LRScheduleVisualizer::print_schedule(step_scheduler, 0.01, 50, 60, 10); + println!("StepLR (step_size=2, gamma=0.5):"); + let step_scheduler = StepLR::new(DEMO_VISUALIZATION_STEP_SIZE, 0.5); + LRScheduleVisualizer::print_schedule(step_scheduler, 0.01, DEMO_VISUALIZATION_STEPS, 40, 5); println!(); // Visualize PolynomialLR println!("PolynomialLR (power=2.0, end_lr=0.001):"); - let poly_scheduler = PolynomialLR::new(50, 2.0, 0.001); - LRScheduleVisualizer::print_schedule(poly_scheduler, 0.01, 50, 60, 10); + let poly_scheduler = PolynomialLR::new(DEMO_VISUALIZATION_STEPS, 2.0, 0.001); + LRScheduleVisualizer::print_schedule(poly_scheduler, 0.01, DEMO_VISUALIZATION_STEPS, 40, 5); println!(); // Visualize CyclicalLR - println!("CyclicalLR Triangular (base_lr=0.001, max_lr=0.01, step_size=8):"); - let cyclical_scheduler = CyclicalLR::new(0.001, 0.01, 8); - LRScheduleVisualizer::print_schedule(cyclical_scheduler, 0.001, 50, 60, 10); + println!("CyclicalLR Triangular (base_lr=0.001, max_lr=0.01, step_size=2):"); + let cyclical_scheduler = CyclicalLR::new(0.001, 0.01, DEMO_CYCLICAL_STEP_SIZE); + LRScheduleVisualizer::print_schedule( + cyclical_scheduler, + 0.001, + DEMO_VISUALIZATION_STEPS, + 40, + 5, + ); println!(); println!("----------------------------------------\n"); @@ -230,25 +217,20 @@ fn advanced_training_example( println!(" Warmup + Cyclical LR + Dropout + Gradient Clipping\n"); // Create network with dropout - let network = LSTMNetwork::new(1, 16, 1) + let network = LSTMNetwork::new(1, DEMO_ADVANCED_HIDDEN_SIZE, 1) .with_input_dropout(0.1, true) // Variational dropout .with_recurrent_dropout(0.2, true) // Variational recurrent dropout .with_output_dropout(0.1); // Standard output dropout // Create warmup scheduler with cyclical base scheduler - let base_scheduler = CyclicalLR::new(0.001, 0.01, 10).with_mode(CyclicalMode::Triangular2); - let warmup_scheduler = WarmupScheduler::new(5, base_scheduler, 0.0001); + let base_scheduler = + CyclicalLR::new(0.001, 0.01, DEMO_CYCLICAL_STEP_SIZE).with_mode(CyclicalMode::Triangular2); + let warmup_scheduler = WarmupScheduler::new(DEMO_WARMUP_EPOCH_COUNT, base_scheduler, 0.0001); let loss_function = MSELoss; let scheduled_optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01); - let config = TrainingConfig { - epochs: 40, - print_every: 5, - clip_gradient: Some(1.0), // Gradient clipping - log_lr_changes: false, // Too frequent for cyclical - early_stopping: None, - }; + let config = advanced_training_config(); let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config); @@ -272,18 +254,17 @@ fn advanced_training_example( println!("\n✅ Advanced training complete!"); } -fn generate_sine_wave_data( +pub fn generate_sine_wave_data( num_sequences: usize, offset: f64, ) -> Vec<(Vec>, Vec>)> { let mut data = Vec::new(); for i in 0..num_sequences { - let sequence_length = 8; let mut inputs = Vec::new(); let mut targets = Vec::new(); - for t in 0..sequence_length { + for t in 0..DEMO_SEQUENCE_LENGTH { let x = (offset + i as f64 * 0.1 + t as f64 * 0.2).sin(); let y = (offset + i as f64 * 0.1 + (t + 1) as f64 * 0.2).sin(); @@ -297,19 +278,59 @@ fn generate_sine_wave_data( data } +pub fn polynomial_decay_training_config() -> TrainingConfig { + TrainingConfig { + epochs: DEMO_POLYNOMIAL_EPOCHS, + print_every: 1, + clip_gradient: Some(1.0), + log_lr_changes: true, + early_stopping: None, + } +} + +pub fn cyclical_lr_training_config() -> TrainingConfig { + TrainingConfig { + epochs: DEMO_CYCLICAL_EPOCHS, + print_every: 1, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: None, + } +} + +pub fn warmup_scheduler_training_config() -> TrainingConfig { + TrainingConfig { + epochs: DEMO_WARMUP_EPOCHS, + print_every: 1, + clip_gradient: Some(1.0), + log_lr_changes: true, + early_stopping: None, + } +} + +pub fn advanced_training_config() -> TrainingConfig { + TrainingConfig { + epochs: DEMO_ADVANCED_EPOCHS, + print_every: 1, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: None, + } +} + #[cfg(test)] mod tests { use super::*; - use rust_lstm::SGD; #[test] fn test_advanced_schedulers() { // Test polynomial scheduler let poly_scheduler = PolynomialLR::new(100, 2.0, 0.01); - let schedule = LRScheduleVisualizer::generate_schedule(poly_scheduler, 0.1, 100); - assert_eq!(schedule.len(), 100); + let schedule = LRScheduleVisualizer::generate_schedule(poly_scheduler, 0.1, 101); + assert_eq!(schedule.len(), 101); assert_eq!(schedule[0].1, 0.1); - assert!((schedule[99].1 - 0.01).abs() < 1e-10); + assert!(schedule[99].1 < schedule[0].1); + assert!((schedule[100].1 - 0.01).abs() < 1e-10); // Test cyclical scheduler let cyclical_scheduler = CyclicalLR::new(0.01, 0.1, 10); diff --git a/tests/example_training_bounds_test.rs b/tests/example_training_bounds_test.rs index ea3b7d0..8cfeff1 100644 --- a/tests/example_training_bounds_test.rs +++ b/tests/example_training_bounds_test.rs @@ -14,6 +14,10 @@ mod early_stopping_example; #[path = "../examples/learning_rate_scheduling.rs"] mod learning_rate_scheduling; +#[allow(dead_code)] +#[path = "../examples/advanced_lr_scheduling.rs"] +mod advanced_lr_scheduling; + use std::hint::black_box; #[test] @@ -229,3 +233,95 @@ fn learning_rate_scheduling_uses_small_deterministic_fixture() { "generated fixtures should use the public demo sequence-length bound" ); } + +#[test] +fn advanced_lr_scheduling_applies_bounded_configs_to_all_demo_paths() { + let configs = [ + advanced_lr_scheduling::polynomial_decay_training_config(), + advanced_lr_scheduling::cyclical_lr_training_config(), + advanced_lr_scheduling::warmup_scheduler_training_config(), + advanced_lr_scheduling::advanced_training_config(), + ]; + + for config in configs { + assert!( + black_box(config.epochs) <= 5, + "advanced_lr_scheduling should keep every demo training path bounded" + ); + assert!( + black_box(config.print_every) > 0, + "advanced_lr_scheduling progress logging should stay enabled" + ); + assert!( + black_box(config.print_every) <= black_box(config.epochs), + "advanced_lr_scheduling progress logging should not exceed the epoch budget" + ); + assert!( + config.early_stopping.is_none(), + "advanced_lr_scheduling examples should avoid hidden early-stopping work" + ); + } + + assert!( + black_box(advanced_lr_scheduling::DEMO_HIDDEN_SIZE) <= 4, + "advanced_lr_scheduling should keep demo hidden size bounded" + ); + assert!( + black_box(advanced_lr_scheduling::DEMO_ADVANCED_HIDDEN_SIZE) <= 6, + "advanced_lr_scheduling should keep advanced demo hidden size bounded" + ); + assert!( + black_box(advanced_lr_scheduling::DEMO_POLYNOMIAL_ITERS) + <= black_box(advanced_lr_scheduling::DEMO_POLYNOMIAL_EPOCHS), + "polynomial scheduler iterations should fit inside the demo epoch budget" + ); + assert!( + black_box(advanced_lr_scheduling::DEMO_CYCLICAL_STEP_SIZE) + <= black_box(advanced_lr_scheduling::DEMO_CYCLICAL_EPOCHS), + "cyclical step size should fit inside the demo epoch budget" + ); + assert!( + black_box(advanced_lr_scheduling::DEMO_WARMUP_EPOCH_COUNT) + <= black_box(advanced_lr_scheduling::DEMO_WARMUP_EPOCHS), + "warmup period should fit inside the demo epoch budget" + ); + assert!( + black_box(advanced_lr_scheduling::DEMO_VISUALIZATION_STEP_SIZE) + <= black_box(advanced_lr_scheduling::DEMO_VISUALIZATION_STEPS), + "visualization scheduler step size should fit inside the visualization budget" + ); + assert!( + black_box(advanced_lr_scheduling::DEMO_VISUALIZATION_STEPS) <= 20, + "schedule visualization should stay bounded for interactive runs" + ); +} + +#[test] +fn advanced_lr_scheduling_uses_small_deterministic_fixture() { + let first = advanced_lr_scheduling::generate_sine_wave_data(black_box(4), 0.0); + let second = advanced_lr_scheduling::generate_sine_wave_data(black_box(4), 0.0); + + assert_eq!( + first, second, + "advanced LR demo fixture should be deterministic" + ); + assert!( + black_box(advanced_lr_scheduling::DEMO_TRAIN_SEQUENCES) <= 12, + "advanced_lr_scheduling should keep training sequence count bounded" + ); + assert!( + black_box(advanced_lr_scheduling::DEMO_VAL_SEQUENCES) <= 4, + "advanced_lr_scheduling should keep validation sequence count bounded" + ); + assert!( + black_box(advanced_lr_scheduling::DEMO_SEQUENCE_LENGTH) <= 6, + "advanced_lr_scheduling should keep each fixture sequence bounded" + ); + assert!( + first.iter().all(|(inputs, targets)| { + inputs.len() == advanced_lr_scheduling::DEMO_SEQUENCE_LENGTH + && targets.len() == advanced_lr_scheduling::DEMO_SEQUENCE_LENGTH + }), + "generated fixtures should use the public demo sequence-length bound" + ); +}