diff --git a/examples/learning_rate_scheduling.rs b/examples/learning_rate_scheduling.rs index db488a4..7394c81 100644 --- a/examples/learning_rate_scheduling.rs +++ b/examples/learning_rate_scheduling.rs @@ -13,13 +13,93 @@ use rust_lstm::{ ScheduledOptimizer, TrainingConfig, }; +pub const DEMO_TRAIN_SEQUENCES: usize = 16; +pub const DEMO_VAL_SEQUENCES: usize = 4; +pub const DEMO_SEQUENCE_LENGTH: usize = 6; +pub const DEMO_HIDDEN_SIZE: usize = 4; +pub const DEMO_COMPARISON_HIDDEN_SIZE: usize = 3; + +pub const DEMO_STEP_EPOCHS: usize = 4; +pub const DEMO_ONE_CYCLE_EPOCHS: usize = 5; +pub const DEMO_COSINE_EPOCHS: usize = 4; +pub const DEMO_EXPONENTIAL_EPOCHS: usize = 4; +pub const DEMO_PLATEAU_EPOCHS: usize = 5; +pub const DEMO_COMPARISON_EPOCHS: usize = 3; + +pub const DEMO_PRINT_EVERY: usize = 1; +pub const DEMO_STEP_PERIOD: usize = 2; +pub const DEMO_COSINE_PERIOD: usize = 3; +pub const DEMO_PLATEAU_PATIENCE: usize = 2; +/// A large step period keeps this comparison path effectively constant. +pub const DEMO_CONSTANT_STEP_PERIOD: usize = 1_000; + +pub fn step_lr_training_config() -> TrainingConfig { + TrainingConfig { + epochs: DEMO_STEP_EPOCHS, + print_every: DEMO_PRINT_EVERY, + clip_gradient: Some(1.0), + log_lr_changes: true, + early_stopping: None, + } +} + +pub fn one_cycle_training_config() -> TrainingConfig { + TrainingConfig { + epochs: DEMO_ONE_CYCLE_EPOCHS, + print_every: DEMO_PRINT_EVERY, + clip_gradient: Some(1.0), + log_lr_changes: false, // Too many changes for OneCycle + early_stopping: None, + } +} + +pub fn cosine_annealing_training_config() -> TrainingConfig { + TrainingConfig { + epochs: DEMO_COSINE_EPOCHS, + print_every: DEMO_PRINT_EVERY, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: None, + } +} + +pub fn exponential_decay_training_config() -> TrainingConfig { + TrainingConfig { + epochs: DEMO_EXPONENTIAL_EPOCHS, + print_every: DEMO_PRINT_EVERY, + clip_gradient: Some(1.0), + log_lr_changes: true, + early_stopping: None, + } +} + +pub fn reduce_on_plateau_training_config() -> TrainingConfig { + TrainingConfig { + epochs: DEMO_PLATEAU_EPOCHS, + print_every: DEMO_PRINT_EVERY, + clip_gradient: Some(1.0), + log_lr_changes: true, + early_stopping: None, + } +} + +pub fn scheduler_comparison_training_config() -> TrainingConfig { + TrainingConfig { + epochs: DEMO_COMPARISON_EPOCHS, + print_every: DEMO_COMPARISON_EPOCHS, // Only print final result + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: None, + } +} + fn main() { println!("Learning Rate Scheduling Examples for Rust-LSTM"); println!("==================================================\n"); // Generate sample training data (sine wave prediction) - let train_data = generate_sine_wave_data(100, 0.0); - let val_data = generate_sine_wave_data(20, 1000.0); + let train_data = generate_sine_wave_data(DEMO_TRAIN_SEQUENCES, 0.0); + let val_data = generate_sine_wave_data(DEMO_VAL_SEQUENCES, 1000.0); // Example 1: Step Learning Rate Decay step_lr_example(&train_data, &val_data); @@ -45,21 +125,16 @@ fn step_lr_example( val_data: &[(Vec>, Vec>)], ) { println!("Step Learning Rate Decay Example"); - println!("Reduces LR by factor of 0.5 every 10 epochs\n"); + println!("Reduces LR by factor of 0.5 every {DEMO_STEP_PERIOD} epochs\n"); - let network = LSTMNetwork::new(1, 10, 2) + let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 2) .with_input_dropout(0.1, false) .with_recurrent_dropout(0.2, true); - let config = TrainingConfig { - epochs: 30, - print_every: 5, - clip_gradient: Some(1.0), - log_lr_changes: true, - early_stopping: None, - }; + let config = step_lr_training_config(); - let mut trainer = create_step_lr_trainer(network, 0.01, 10, 0.5).with_config(config); + let mut trainer = + create_step_lr_trainer(network, 0.01, DEMO_STEP_PERIOD, 0.5).with_config(config); trainer.train(train_data, Some(val_data)); @@ -74,17 +149,12 @@ fn one_cycle_example( println!("OneCycle Learning Rate Policy Example"); println!("Starts low, ramps up to max, then anneals down\n"); - let network = LSTMNetwork::new(1, 10, 2); + let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 2); - let config = TrainingConfig { - epochs: 50, - print_every: 10, - clip_gradient: Some(1.0), - log_lr_changes: false, // Too many changes for OneCycle - early_stopping: None, - }; + let config = one_cycle_training_config(); - let mut trainer = create_one_cycle_trainer(network, 0.1, 50).with_config(config); + let mut trainer = + create_one_cycle_trainer(network, 0.1, DEMO_ONE_CYCLE_EPOCHS).with_config(config); trainer.train(train_data, Some(val_data)); @@ -99,17 +169,12 @@ fn cosine_annealing_example( println!("Cosine Annealing Example"); println!("Smoothly oscillates LR following cosine curve\n"); - let network = LSTMNetwork::new(1, 10, 2); + let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 2); - let config = TrainingConfig { - epochs: 40, - print_every: 8, - clip_gradient: Some(1.0), - log_lr_changes: false, - early_stopping: None, - }; + let config = cosine_annealing_training_config(); - let mut trainer = create_cosine_annealing_trainer(network, 0.01, 20, 1e-6).with_config(config); + let mut trainer = create_cosine_annealing_trainer(network, 0.01, DEMO_COSINE_PERIOD, 1e-6) + .with_config(config); trainer.train(train_data, Some(val_data)); @@ -124,18 +189,12 @@ fn exponential_decay_example( println!("Exponential Decay Example"); println!("Continuously decays LR by factor of 0.95 each epoch\n"); - let network = LSTMNetwork::new(1, 10, 2); + let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 2); let loss_function = MSELoss; let scheduled_optimizer = ScheduledOptimizer::exponential(Adam::new(0.01), 0.01, 0.95); - let config = TrainingConfig { - epochs: 30, - print_every: 6, - clip_gradient: Some(1.0), - log_lr_changes: true, - early_stopping: None, - }; + let config = exponential_decay_training_config(); let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config); @@ -153,20 +212,14 @@ fn reduce_on_plateau_example( println!("ReduceLROnPlateau Example"); println!("Reduces LR when validation loss stops improving\n"); - let _network = LSTMNetwork::new(1, 10, 2); + let _network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 2); // Create a plateau scheduler manually since we need special handling - let mut plateau_scheduler = ReduceLROnPlateau::new(0.5, 5); + let mut plateau_scheduler = ReduceLROnPlateau::new(0.5, DEMO_PLATEAU_PATIENCE); let mut optimizer = Adam::new(0.01); let _loss_function = MSELoss; - let config = TrainingConfig { - epochs: 40, - print_every: 5, - clip_gradient: Some(1.0), - log_lr_changes: true, - early_stopping: None, - }; + let config = reduce_on_plateau_training_config(); println!("Training with manual ReduceLROnPlateau stepping..."); @@ -211,20 +264,15 @@ fn scheduler_comparison( for (name, scheduler_type) in schedulers { println!("Testing {} scheduler:", name); - let network = LSTMNetwork::new(1, 8, 1); // Smaller network for faster comparison + let network = LSTMNetwork::new(1, DEMO_COMPARISON_HIDDEN_SIZE, 1); // Smaller network for faster comparison - let config = TrainingConfig { - epochs: 20, - print_every: 20, // Only print final result - clip_gradient: Some(1.0), - log_lr_changes: false, - early_stopping: None, - }; + let config = scheduler_comparison_training_config(); let final_loss = match scheduler_type { "constant" => { - let mut trainer = create_step_lr_trainer(network, 0.01, 1000, 1.0) // Effectively constant - .with_config(config); + let mut trainer = + create_step_lr_trainer(network, 0.01, DEMO_CONSTANT_STEP_PERIOD, 1.0) // Effectively constant + .with_config(config); trainer.train(train_data, Some(val_data)); trainer .get_latest_metrics() @@ -233,8 +281,8 @@ fn scheduler_comparison( .unwrap_or(0.0) } "step" => { - let mut trainer = - create_step_lr_trainer(network, 0.01, 10, 0.5).with_config(config); + let mut trainer = create_step_lr_trainer(network, 0.01, DEMO_STEP_PERIOD, 0.5) + .with_config(config); trainer.train(train_data, Some(val_data)); trainer .get_latest_metrics() @@ -257,7 +305,8 @@ fn scheduler_comparison( .unwrap_or(0.0) } "onecycle" => { - let mut trainer = create_one_cycle_trainer(network, 0.05, 20).with_config(config); + let mut trainer = create_one_cycle_trainer(network, 0.05, DEMO_COMPARISON_EPOCHS) + .with_config(config); trainer.train(train_data, Some(val_data)); trainer .get_latest_metrics() @@ -274,14 +323,14 @@ fn scheduler_comparison( println!("Comparison complete! Check which scheduler performed best."); } -fn generate_sine_wave_data( +pub fn generate_sine_wave_data( num_sequences: usize, offset: f64, ) -> Vec<(Vec>, Vec>)> { let mut data = Vec::new(); for i in 0..num_sequences { - let sequence_length = 10; + let sequence_length = DEMO_SEQUENCE_LENGTH; let mut inputs = Vec::new(); let mut targets = Vec::new(); diff --git a/tests/example_training_bounds_test.rs b/tests/example_training_bounds_test.rs index a5ab898..ea3b7d0 100644 --- a/tests/example_training_bounds_test.rs +++ b/tests/example_training_bounds_test.rs @@ -10,6 +10,10 @@ mod stock_prediction; #[path = "../examples/early_stopping_example.rs"] mod early_stopping_example; +#[allow(dead_code)] +#[path = "../examples/learning_rate_scheduling.rs"] +mod learning_rate_scheduling; + use std::hint::black_box; #[test] @@ -128,3 +132,100 @@ fn early_stopping_example_uses_small_deterministic_fixture() { "early_stopping_example should keep each fixture sequence bounded" ); } + +#[test] +fn learning_rate_scheduling_applies_bounded_configs_to_all_demo_paths() { + let configs = [ + learning_rate_scheduling::step_lr_training_config(), + learning_rate_scheduling::one_cycle_training_config(), + learning_rate_scheduling::cosine_annealing_training_config(), + learning_rate_scheduling::exponential_decay_training_config(), + learning_rate_scheduling::reduce_on_plateau_training_config(), + learning_rate_scheduling::scheduler_comparison_training_config(), + ]; + + for config in configs { + assert!( + black_box(config.epochs) <= 5, + "learning_rate_scheduling should keep every demo training path bounded" + ); + assert!( + black_box(config.print_every) > 0, + "learning_rate_scheduling progress logging should stay enabled" + ); + assert!( + black_box(config.print_every) <= black_box(config.epochs), + "learning_rate_scheduling progress logging should not exceed the epoch budget" + ); + assert!( + config.early_stopping.is_none(), + "learning_rate_scheduling examples should avoid hidden early-stopping work" + ); + } + + assert!( + black_box(learning_rate_scheduling::DEMO_HIDDEN_SIZE) <= 4, + "learning_rate_scheduling should keep demo hidden size bounded" + ); + assert!( + black_box(learning_rate_scheduling::DEMO_COMPARISON_HIDDEN_SIZE) <= 4, + "learning_rate_scheduling should keep comparison hidden size bounded" + ); + assert!( + black_box(learning_rate_scheduling::DEMO_STEP_PERIOD) > 0, + "step scheduler period should be non-zero" + ); + assert!( + black_box(learning_rate_scheduling::DEMO_STEP_PERIOD) + <= black_box(learning_rate_scheduling::DEMO_STEP_EPOCHS), + "step scheduler period should fit inside the demo epoch budget" + ); + assert!( + black_box(learning_rate_scheduling::DEMO_COSINE_PERIOD) > 0, + "cosine scheduler period should be non-zero" + ); + assert!( + black_box(learning_rate_scheduling::DEMO_COSINE_PERIOD) + <= black_box(learning_rate_scheduling::DEMO_COSINE_EPOCHS), + "cosine scheduler period should fit inside the demo epoch budget" + ); + assert!( + black_box(learning_rate_scheduling::DEMO_PLATEAU_PATIENCE) > 0, + "plateau patience should be non-zero" + ); + assert!( + black_box(learning_rate_scheduling::DEMO_PLATEAU_PATIENCE) + <= black_box(learning_rate_scheduling::DEMO_PLATEAU_EPOCHS), + "plateau patience should fit inside the demo epoch budget" + ); +} + +#[test] +fn learning_rate_scheduling_uses_small_deterministic_fixture() { + let first = learning_rate_scheduling::generate_sine_wave_data(black_box(4), 0.0); + let second = learning_rate_scheduling::generate_sine_wave_data(black_box(4), 0.0); + + assert_eq!( + first, second, + "demo sine-wave fixture should be deterministic" + ); + assert!( + black_box(learning_rate_scheduling::DEMO_TRAIN_SEQUENCES) <= 16, + "learning_rate_scheduling should keep training sequence count bounded" + ); + assert!( + black_box(learning_rate_scheduling::DEMO_VAL_SEQUENCES) <= 4, + "learning_rate_scheduling should keep validation sequence count bounded" + ); + assert!( + black_box(learning_rate_scheduling::DEMO_SEQUENCE_LENGTH) <= 6, + "learning_rate_scheduling should keep each fixture sequence bounded" + ); + assert!( + first.iter().all(|(inputs, targets)| { + inputs.len() == learning_rate_scheduling::DEMO_SEQUENCE_LENGTH + && targets.len() == learning_rate_scheduling::DEMO_SEQUENCE_LENGTH + }), + "generated fixtures should use the public demo sequence-length bound" + ); +}