Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 111 additions & 62 deletions examples/learning_rate_scheduling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,93 @@ use rust_lstm::{
ScheduledOptimizer, TrainingConfig,
};

pub const DEMO_TRAIN_SEQUENCES: usize = 16;
pub const DEMO_VAL_SEQUENCES: usize = 4;
pub const DEMO_SEQUENCE_LENGTH: usize = 6;
pub const DEMO_HIDDEN_SIZE: usize = 4;
pub const DEMO_COMPARISON_HIDDEN_SIZE: usize = 3;

pub const DEMO_STEP_EPOCHS: usize = 4;
pub const DEMO_ONE_CYCLE_EPOCHS: usize = 5;
pub const DEMO_COSINE_EPOCHS: usize = 4;
pub const DEMO_EXPONENTIAL_EPOCHS: usize = 4;
pub const DEMO_PLATEAU_EPOCHS: usize = 5;
pub const DEMO_COMPARISON_EPOCHS: usize = 3;

pub const DEMO_PRINT_EVERY: usize = 1;
pub const DEMO_STEP_PERIOD: usize = 2;
pub const DEMO_COSINE_PERIOD: usize = 3;
pub const DEMO_PLATEAU_PATIENCE: usize = 2;
/// A large step period keeps this comparison path effectively constant.
pub const DEMO_CONSTANT_STEP_PERIOD: usize = 1_000;

pub fn step_lr_training_config() -> TrainingConfig {
TrainingConfig {
epochs: DEMO_STEP_EPOCHS,
print_every: DEMO_PRINT_EVERY,
clip_gradient: Some(1.0),
log_lr_changes: true,
early_stopping: None,
}
}

pub fn one_cycle_training_config() -> TrainingConfig {
TrainingConfig {
epochs: DEMO_ONE_CYCLE_EPOCHS,
print_every: DEMO_PRINT_EVERY,
clip_gradient: Some(1.0),
log_lr_changes: false, // Too many changes for OneCycle
early_stopping: None,
}
}

pub fn cosine_annealing_training_config() -> TrainingConfig {
TrainingConfig {
epochs: DEMO_COSINE_EPOCHS,
print_every: DEMO_PRINT_EVERY,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: None,
}
}

pub fn exponential_decay_training_config() -> TrainingConfig {
TrainingConfig {
epochs: DEMO_EXPONENTIAL_EPOCHS,
print_every: DEMO_PRINT_EVERY,
clip_gradient: Some(1.0),
log_lr_changes: true,
early_stopping: None,
}
}

pub fn reduce_on_plateau_training_config() -> TrainingConfig {
TrainingConfig {
epochs: DEMO_PLATEAU_EPOCHS,
print_every: DEMO_PRINT_EVERY,
clip_gradient: Some(1.0),
log_lr_changes: true,
early_stopping: None,
}
}

pub fn scheduler_comparison_training_config() -> TrainingConfig {
TrainingConfig {
epochs: DEMO_COMPARISON_EPOCHS,
print_every: DEMO_COMPARISON_EPOCHS, // Only print final result
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: None,
}
}

fn main() {
println!("Learning Rate Scheduling Examples for Rust-LSTM");
println!("==================================================\n");

// Generate sample training data (sine wave prediction)
let train_data = generate_sine_wave_data(100, 0.0);
let val_data = generate_sine_wave_data(20, 1000.0);
let train_data = generate_sine_wave_data(DEMO_TRAIN_SEQUENCES, 0.0);
let val_data = generate_sine_wave_data(DEMO_VAL_SEQUENCES, 1000.0);

// Example 1: Step Learning Rate Decay
step_lr_example(&train_data, &val_data);
Expand All @@ -45,21 +125,16 @@ fn step_lr_example(
val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
) {
println!("Step Learning Rate Decay Example");
println!("Reduces LR by factor of 0.5 every 10 epochs\n");
println!("Reduces LR by factor of 0.5 every {DEMO_STEP_PERIOD} epochs\n");

let network = LSTMNetwork::new(1, 10, 2)
let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 2)
.with_input_dropout(0.1, false)
.with_recurrent_dropout(0.2, true);

let config = TrainingConfig {
epochs: 30,
print_every: 5,
clip_gradient: Some(1.0),
log_lr_changes: true,
early_stopping: None,
};
let config = step_lr_training_config();

let mut trainer = create_step_lr_trainer(network, 0.01, 10, 0.5).with_config(config);
let mut trainer =
create_step_lr_trainer(network, 0.01, DEMO_STEP_PERIOD, 0.5).with_config(config);

trainer.train(train_data, Some(val_data));

Expand All @@ -74,17 +149,12 @@ fn one_cycle_example(
println!("OneCycle Learning Rate Policy Example");
println!("Starts low, ramps up to max, then anneals down\n");

let network = LSTMNetwork::new(1, 10, 2);
let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 2);

let config = TrainingConfig {
epochs: 50,
print_every: 10,
clip_gradient: Some(1.0),
log_lr_changes: false, // Too many changes for OneCycle
early_stopping: None,
};
let config = one_cycle_training_config();

let mut trainer = create_one_cycle_trainer(network, 0.1, 50).with_config(config);
let mut trainer =
create_one_cycle_trainer(network, 0.1, DEMO_ONE_CYCLE_EPOCHS).with_config(config);

trainer.train(train_data, Some(val_data));

Expand All @@ -99,17 +169,12 @@ fn cosine_annealing_example(
println!("Cosine Annealing Example");
println!("Smoothly oscillates LR following cosine curve\n");

let network = LSTMNetwork::new(1, 10, 2);
let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 2);

let config = TrainingConfig {
epochs: 40,
print_every: 8,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: None,
};
let config = cosine_annealing_training_config();

let mut trainer = create_cosine_annealing_trainer(network, 0.01, 20, 1e-6).with_config(config);
let mut trainer = create_cosine_annealing_trainer(network, 0.01, DEMO_COSINE_PERIOD, 1e-6)
.with_config(config);

trainer.train(train_data, Some(val_data));

Expand All @@ -124,18 +189,12 @@ fn exponential_decay_example(
println!("Exponential Decay Example");
println!("Continuously decays LR by factor of 0.95 each epoch\n");

let network = LSTMNetwork::new(1, 10, 2);
let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 2);

let loss_function = MSELoss;
let scheduled_optimizer = ScheduledOptimizer::exponential(Adam::new(0.01), 0.01, 0.95);

let config = TrainingConfig {
epochs: 30,
print_every: 6,
clip_gradient: Some(1.0),
log_lr_changes: true,
early_stopping: None,
};
let config = exponential_decay_training_config();

let mut trainer =
ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config);
Expand All @@ -153,20 +212,14 @@ fn reduce_on_plateau_example(
println!("ReduceLROnPlateau Example");
println!("Reduces LR when validation loss stops improving\n");

let _network = LSTMNetwork::new(1, 10, 2);
let _network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 2);

// Create a plateau scheduler manually since we need special handling
let mut plateau_scheduler = ReduceLROnPlateau::new(0.5, 5);
let mut plateau_scheduler = ReduceLROnPlateau::new(0.5, DEMO_PLATEAU_PATIENCE);
let mut optimizer = Adam::new(0.01);
let _loss_function = MSELoss;

let config = TrainingConfig {
epochs: 40,
print_every: 5,
clip_gradient: Some(1.0),
log_lr_changes: true,
early_stopping: None,
};
let config = reduce_on_plateau_training_config();

println!("Training with manual ReduceLROnPlateau stepping...");

Expand Down Expand Up @@ -211,20 +264,15 @@ fn scheduler_comparison(
for (name, scheduler_type) in schedulers {
println!("Testing {} scheduler:", name);

let network = LSTMNetwork::new(1, 8, 1); // Smaller network for faster comparison
let network = LSTMNetwork::new(1, DEMO_COMPARISON_HIDDEN_SIZE, 1); // Smaller network for faster comparison

let config = TrainingConfig {
epochs: 20,
print_every: 20, // Only print final result
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: None,
};
let config = scheduler_comparison_training_config();

let final_loss = match scheduler_type {
"constant" => {
let mut trainer = create_step_lr_trainer(network, 0.01, 1000, 1.0) // Effectively constant
.with_config(config);
let mut trainer =
create_step_lr_trainer(network, 0.01, DEMO_CONSTANT_STEP_PERIOD, 1.0) // Effectively constant
.with_config(config);
trainer.train(train_data, Some(val_data));
trainer
.get_latest_metrics()
Expand All @@ -233,8 +281,8 @@ fn scheduler_comparison(
.unwrap_or(0.0)
}
"step" => {
let mut trainer =
create_step_lr_trainer(network, 0.01, 10, 0.5).with_config(config);
let mut trainer = create_step_lr_trainer(network, 0.01, DEMO_STEP_PERIOD, 0.5)
.with_config(config);
trainer.train(train_data, Some(val_data));
trainer
.get_latest_metrics()
Expand All @@ -257,7 +305,8 @@ fn scheduler_comparison(
.unwrap_or(0.0)
}
"onecycle" => {
let mut trainer = create_one_cycle_trainer(network, 0.05, 20).with_config(config);
let mut trainer = create_one_cycle_trainer(network, 0.05, DEMO_COMPARISON_EPOCHS)
.with_config(config);
trainer.train(train_data, Some(val_data));
trainer
.get_latest_metrics()
Expand All @@ -274,14 +323,14 @@ fn scheduler_comparison(
println!("Comparison complete! Check which scheduler performed best.");
}

fn generate_sine_wave_data(
pub fn generate_sine_wave_data(
num_sequences: usize,
offset: f64,
) -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> {
let mut data = Vec::new();

for i in 0..num_sequences {
let sequence_length = 10;
let sequence_length = DEMO_SEQUENCE_LENGTH;
let mut inputs = Vec::new();
let mut targets = Vec::new();

Expand Down
101 changes: 101 additions & 0 deletions tests/example_training_bounds_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ mod stock_prediction;
#[path = "../examples/early_stopping_example.rs"]
mod early_stopping_example;

#[allow(dead_code)]
#[path = "../examples/learning_rate_scheduling.rs"]
mod learning_rate_scheduling;

use std::hint::black_box;

#[test]
Expand Down Expand Up @@ -128,3 +132,100 @@ fn early_stopping_example_uses_small_deterministic_fixture() {
"early_stopping_example should keep each fixture sequence bounded"
);
}

#[test]
fn learning_rate_scheduling_applies_bounded_configs_to_all_demo_paths() {
let configs = [
learning_rate_scheduling::step_lr_training_config(),
learning_rate_scheduling::one_cycle_training_config(),
learning_rate_scheduling::cosine_annealing_training_config(),
learning_rate_scheduling::exponential_decay_training_config(),
learning_rate_scheduling::reduce_on_plateau_training_config(),
learning_rate_scheduling::scheduler_comparison_training_config(),
];

for config in configs {
assert!(
black_box(config.epochs) <= 5,
"learning_rate_scheduling should keep every demo training path bounded"
);
assert!(
black_box(config.print_every) > 0,
"learning_rate_scheduling progress logging should stay enabled"
);
assert!(
black_box(config.print_every) <= black_box(config.epochs),
"learning_rate_scheduling progress logging should not exceed the epoch budget"
);
assert!(
config.early_stopping.is_none(),
"learning_rate_scheduling examples should avoid hidden early-stopping work"
);
}

assert!(
black_box(learning_rate_scheduling::DEMO_HIDDEN_SIZE) <= 4,
"learning_rate_scheduling should keep demo hidden size bounded"
);
assert!(
black_box(learning_rate_scheduling::DEMO_COMPARISON_HIDDEN_SIZE) <= 4,
"learning_rate_scheduling should keep comparison hidden size bounded"
);
assert!(
black_box(learning_rate_scheduling::DEMO_STEP_PERIOD) > 0,
"step scheduler period should be non-zero"
);
assert!(
black_box(learning_rate_scheduling::DEMO_STEP_PERIOD)
<= black_box(learning_rate_scheduling::DEMO_STEP_EPOCHS),
"step scheduler period should fit inside the demo epoch budget"
);
assert!(
black_box(learning_rate_scheduling::DEMO_COSINE_PERIOD) > 0,
"cosine scheduler period should be non-zero"
);
assert!(
black_box(learning_rate_scheduling::DEMO_COSINE_PERIOD)
<= black_box(learning_rate_scheduling::DEMO_COSINE_EPOCHS),
"cosine scheduler period should fit inside the demo epoch budget"
);
assert!(
black_box(learning_rate_scheduling::DEMO_PLATEAU_PATIENCE) > 0,
"plateau patience should be non-zero"
);
assert!(
black_box(learning_rate_scheduling::DEMO_PLATEAU_PATIENCE)
<= black_box(learning_rate_scheduling::DEMO_PLATEAU_EPOCHS),
"plateau patience should fit inside the demo epoch budget"
);
}

#[test]
fn learning_rate_scheduling_uses_small_deterministic_fixture() {
let first = learning_rate_scheduling::generate_sine_wave_data(black_box(4), 0.0);
let second = learning_rate_scheduling::generate_sine_wave_data(black_box(4), 0.0);

assert_eq!(
first, second,
"demo sine-wave fixture should be deterministic"
);
assert!(
black_box(learning_rate_scheduling::DEMO_TRAIN_SEQUENCES) <= 16,
"learning_rate_scheduling should keep training sequence count bounded"
);
assert!(
black_box(learning_rate_scheduling::DEMO_VAL_SEQUENCES) <= 4,
"learning_rate_scheduling should keep validation sequence count bounded"
);
assert!(
black_box(learning_rate_scheduling::DEMO_SEQUENCE_LENGTH) <= 6,
"learning_rate_scheduling should keep each fixture sequence bounded"
);
assert!(
first.iter().all(|(inputs, targets)| {
inputs.len() == learning_rate_scheduling::DEMO_SEQUENCE_LENGTH
&& targets.len() == learning_rate_scheduling::DEMO_SEQUENCE_LENGTH
}),
"generated fixtures should use the public demo sequence-length bound"
);
}
Loading