Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions examples/dropout_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ use rust_lstm::{
LSTMNetwork, LayerDropoutConfig,
};

pub const DEMO_TRAIN_SEQUENCES: usize = 8;
pub const DEMO_SEQUENCE_LENGTH: usize = 4;
pub const DEMO_HIDDEN_SIZE: usize = 4;
pub const DEMO_EPOCHS: usize = 4;
pub const DEMO_PRINT_EVERY: usize = 1;

fn main() {
println!("Rust LSTM Dropout Example");
println!("=========================\n");
Expand Down Expand Up @@ -190,7 +196,7 @@ fn demonstrate_training_with_dropout() {
println!("------------------------");

let input_size = 2;
let hidden_size = 4;
let hidden_size = DEMO_HIDDEN_SIZE;
let num_layers = 2;

// Create network with comprehensive dropout
Expand All @@ -204,19 +210,10 @@ fn demonstrate_training_with_dropout() {
let loss_function = MSELoss;
let optimizer = Adam::new(0.001);
let mut trainer = LSTMTrainer::new(network, loss_function, optimizer);

// Configure training
let config = TrainingConfig {
epochs: 20,
print_every: 5,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: None,
};
trainer = trainer.with_config(config);
trainer = trainer.with_config(dropout_training_config());

// Generate simple training data (sine wave prediction)
let train_data = generate_sine_wave_data(10, 5);
let train_data = generate_sine_wave_data(DEMO_TRAIN_SEQUENCES, DEMO_SEQUENCE_LENGTH);

println!("Training LSTM with dropout regularization...");
println!(
Expand Down Expand Up @@ -251,7 +248,17 @@ fn demonstrate_training_with_dropout() {
println!("\nTraining completed with dropout regularization!");
}

fn generate_sine_wave_data(
pub fn dropout_training_config() -> TrainingConfig {
TrainingConfig {
epochs: DEMO_EPOCHS,
print_every: DEMO_PRINT_EVERY,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: None,
}
}

pub fn generate_sine_wave_data(
num_sequences: usize,
sequence_length: usize,
) -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> {
Expand Down
62 changes: 62 additions & 0 deletions tests/example_training_bounds_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ mod learning_rate_scheduling;
#[path = "../examples/advanced_lr_scheduling.rs"]
mod advanced_lr_scheduling;

#[allow(dead_code)]
#[path = "../examples/dropout_example.rs"]
mod dropout_example;

use std::hint::black_box;

#[test]
Expand Down Expand Up @@ -234,6 +238,64 @@ fn learning_rate_scheduling_uses_small_deterministic_fixture() {
);
}

#[test]
fn dropout_example_applies_bounded_training_config() {
let config = dropout_example::dropout_training_config();

assert!(
black_box(config.epochs) <= 4,
"dropout_example should avoid the default 100-epoch training config"
);
assert!(
black_box(config.print_every) > 0,
"dropout_example progress logging should stay enabled"
);
assert!(
black_box(config.print_every) <= black_box(config.epochs),
"dropout_example progress logging should not exceed the epoch budget"
);
assert!(
config.early_stopping.is_none(),
"dropout_example should avoid hidden early-stopping work"
);
assert!(
black_box(dropout_example::DEMO_HIDDEN_SIZE) <= 4,
"dropout_example should keep demo hidden size bounded"
);
}

#[test]
fn dropout_example_uses_small_deterministic_fixture() {
let first = dropout_example::generate_sine_wave_data(
black_box(dropout_example::DEMO_TRAIN_SEQUENCES),
black_box(dropout_example::DEMO_SEQUENCE_LENGTH),
);
let second = dropout_example::generate_sine_wave_data(
black_box(dropout_example::DEMO_TRAIN_SEQUENCES),
black_box(dropout_example::DEMO_SEQUENCE_LENGTH),
);

assert_eq!(
first, second,
"dropout demo fixture should be deterministic"
);
assert!(
black_box(dropout_example::DEMO_TRAIN_SEQUENCES) <= 8,
"dropout_example should keep training sequence count bounded"
);
assert!(
black_box(dropout_example::DEMO_SEQUENCE_LENGTH) <= 4,
"dropout_example should keep each fixture sequence bounded"
);
assert!(
first.iter().all(|(inputs, targets)| {
inputs.len() == dropout_example::DEMO_SEQUENCE_LENGTH
&& targets.len() == dropout_example::DEMO_SEQUENCE_LENGTH
}),
"generated fixtures should use the public demo sequence-length bound"
);
}

#[test]
fn advanced_lr_scheduling_applies_bounded_configs_to_all_demo_paths() {
let configs = [
Expand Down
Loading