From 1301772fd7b3f5b2084cd6bf931d418348a11813 Mon Sep 17 00:00:00 2001 From: kholdrex Date: Mon, 1 Jun 2026 14:53:18 -0500 Subject: [PATCH] test: bound example training runs --- examples/stock_prediction.rs | 45 +++++++++----- examples/training_example.rs | 43 ++++++++++--- tests/example_training_bounds_test.rs | 89 ++++++++++++++------------- 3 files changed, 109 insertions(+), 68 deletions(-) diff --git a/examples/stock_prediction.rs b/examples/stock_prediction.rs index abe6c81..857342d 100644 --- a/examples/stock_prediction.rs +++ b/examples/stock_prediction.rs @@ -13,15 +13,26 @@ use rust_lstm::models::lstm_network::LSTMNetwork; use rust_lstm::optimizers::Adam; use rust_lstm::training::{LSTMTrainer, TrainingConfig}; -fn demo_training_config() -> TrainingConfig { +pub(crate) const DEMO_EPOCHS: usize = 2; +pub(crate) const DEMO_PRINT_EVERY: usize = 1; +pub(crate) const DEMO_STOCK_DAYS: usize = 80; +pub(crate) const DEMO_SEQUENCE_LENGTH: usize = 5; +pub(crate) const DEMO_HIDDEN_SIZE: usize = 8; +pub(crate) const DEMO_RNG_SEED: u64 = 42; + +pub(crate) fn demo_training_config() -> TrainingConfig { TrainingConfig { - epochs: 2, - print_every: 1, + epochs: DEMO_EPOCHS, + print_every: DEMO_PRINT_EVERY, log_lr_changes: false, ..TrainingConfig::default() } } +pub(crate) fn demo_stock_trainer(network: LSTMNetwork) -> LSTMTrainer { + LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(demo_training_config()) +} + /// Stock data point with OHLCV (Open, High, Low, Close, Volume) #[derive(Debug, Clone)] #[allow(dead_code)] @@ -147,11 +158,7 @@ impl StockPredictor { val_data.len() ); - // Create trainer with Adam optimizer - let loss_function = MSELoss; - let optimizer = Adam::new(0.001); - let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer) - .with_config(demo_training_config()); + let mut trainer = demo_stock_trainer(self.network.clone()); // Train the model trainer.train(train_data, Some(val_data)); @@ -189,7 +196,7 @@ impl StockPredictor { /// Generate synthetic stock data for demonstration fn generate_stock_data(days: usize) -> Vec { - let mut rng = StdRng::seed_from_u64(42); + let mut rng = StdRng::seed_from_u64(DEMO_RNG_SEED); let mut data = Vec::new(); let mut price = 100.0; let volume_base = 1_000_000.0; @@ -226,13 +233,21 @@ fn generate_stock_data(days: usize) -> Vec { data } +#[allow(dead_code)] +pub(crate) fn demo_stock_closes(days: usize) -> Vec { + generate_stock_data(days) + .into_iter() + .map(|stock| stock.close) + .collect() +} + fn main() { println!("šŸ¦ Stock Price Prediction with LSTM"); println!("=====================================\n"); println!("This bounded demo favors quick execution over prediction quality.\n"); // Generate synthetic stock data (in practice, you'd load real data) - let stock_data = generate_stock_data(80); // Bounded synthetic dataset for a quick demo + let stock_data = generate_stock_data(DEMO_STOCK_DAYS); // Bounded synthetic dataset for a quick demo println!( "šŸ“ˆ Generated {} days of synthetic stock data", stock_data.len() @@ -250,16 +265,16 @@ fn main() { } // Create and train predictor - let mut predictor = StockPredictor::new(5, 8); // 5-day sequences, 8 hidden units + let mut predictor = StockPredictor::new(DEMO_SEQUENCE_LENGTH, DEMO_HIDDEN_SIZE); predictor.train(&stock_data, 0.2); // 80% train, 20% validation // Make predictions on recent data println!("\nšŸ”® Making predictions..."); - let recent_data = &stock_data[stock_data.len() - 10..]; // Last 10 days + let recent_window = DEMO_SEQUENCE_LENGTH * 2; + let recent_data = &stock_data[stock_data.len() - recent_window..]; - for i in 5..10 { - // Predict for days 6-10 of recent data - let input_data = &recent_data[i - 5..i]; + for i in DEMO_SEQUENCE_LENGTH..recent_window { + let input_data = &recent_data[i - DEMO_SEQUENCE_LENGTH..i]; if let Some(predicted_price) = predictor.predict_next_price(input_data) { let actual_price = recent_data[i].close; let error = (predicted_price - actual_price).abs(); diff --git a/examples/training_example.rs b/examples/training_example.rs index b47b256..c11737b 100644 --- a/examples/training_example.rs +++ b/examples/training_example.rs @@ -12,15 +12,44 @@ use rust_lstm::models::lstm_network::LSTMNetwork; use rust_lstm::optimizers::{Adam, SGD}; use rust_lstm::training::{LSTMTrainer, TrainingConfig}; -fn demo_training_config() -> TrainingConfig { +pub(crate) const DEMO_EPOCHS: usize = 5; +pub(crate) const DEMO_PRINT_EVERY: usize = 1; + +pub(crate) fn demo_training_config() -> TrainingConfig { TrainingConfig { - epochs: 5, - print_every: 1, + epochs: DEMO_EPOCHS, + print_every: DEMO_PRINT_EVERY, log_lr_changes: false, ..TrainingConfig::default() } } +pub(crate) fn demo_sgd_trainer( + input_size: usize, + hidden_size: usize, + num_layers: usize, +) -> LSTMTrainer { + LSTMTrainer::new( + LSTMNetwork::new(input_size, hidden_size, num_layers), + MSELoss, + SGD::new(0.01), + ) + .with_config(demo_training_config()) +} + +pub(crate) fn demo_adam_trainer( + input_size: usize, + hidden_size: usize, + num_layers: usize, +) -> LSTMTrainer { + LSTMTrainer::new( + LSTMNetwork::new(input_size, hidden_size, num_layers), + MSELoss, + Adam::new(0.001), + ) + .with_config(demo_training_config()) +} + /// Generate sine wave training data for sequence prediction fn generate_sine_data( num_sequences: usize, @@ -95,9 +124,7 @@ fn main() { // Training with SGD println!("Training with SGD optimizer:"); - let network = LSTMNetwork::new(input_size, hidden_size, num_layers); - let mut trainer_sgd = - LSTMTrainer::new(network, MSELoss, SGD::new(0.01)).with_config(demo_training_config()); + let mut trainer_sgd = demo_sgd_trainer(input_size, hidden_size, num_layers); trainer_sgd.train(&train_data, Some(&val_data)); @@ -118,9 +145,7 @@ fn main() { // Training with Adam println!("Training with Adam optimizer:"); - let network = LSTMNetwork::new(input_size, hidden_size, num_layers); - let mut trainer_adam = - LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(demo_training_config()); + let mut trainer_adam = demo_adam_trainer(input_size, hidden_size, num_layers); trainer_adam.train(&train_data, Some(&val_data)); diff --git a/tests/example_training_bounds_test.rs b/tests/example_training_bounds_test.rs index 000971c..4f0d498 100644 --- a/tests/example_training_bounds_test.rs +++ b/tests/example_training_bounds_test.rs @@ -1,67 +1,68 @@ +#[allow(dead_code)] +#[path = "../examples/training_example.rs"] +mod training_example; + +#[allow(dead_code)] +#[path = "../examples/stock_prediction.rs"] +mod stock_prediction; + +use std::hint::black_box; + #[test] -fn training_example_uses_bounded_demo_epochs() { - let source = include_str!("../examples/training_example.rs"); +fn training_example_applies_bounded_config_to_both_demo_trainers() { + let sgd_trainer = training_example::demo_sgd_trainer(1, 10, 1); + let adam_trainer = training_example::demo_adam_trainer(1, 10, 1); - assert!( - source.contains("epochs: 5"), - "training_example should keep interactive demo training bounded" - ); - assert!( - source.contains("fn demo_training_config() -> TrainingConfig"), - "training_example should centralize its bounded demo training config" - ); - assert!( - source - .matches(".with_config(demo_training_config())") - .count() - >= 2, - "both optimizer demonstrations should apply the bounded demo config" - ); + for epochs in [sgd_trainer.config.epochs, adam_trainer.config.epochs] { + assert!( + black_box(epochs) <= 5, + "training_example should keep both optimizer demos bounded" + ); + } + + for print_every in [ + sgd_trainer.config.print_every, + adam_trainer.config.print_every, + ] { + assert!( + black_box(print_every) <= black_box(training_example::DEMO_EPOCHS), + "training_example should report progress without exceeding the epoch budget" + ); + } } #[test] -fn stock_prediction_example_uses_bounded_demo_epochs() { - let source = include_str!("../examples/stock_prediction.rs"); +fn stock_prediction_applies_bounded_config_to_demo_trainer() { + let network = + rust_lstm::models::lstm_network::LSTMNetwork::new(5, stock_prediction::DEMO_HIDDEN_SIZE, 2); + let trainer = stock_prediction::demo_stock_trainer(network); assert!( - source.contains("epochs: 2"), + black_box(trainer.config.epochs) <= 2, "stock_prediction should avoid the default 100-epoch training config" ); assert!( - source.contains("StdRng::seed_from_u64(42)"), - "stock_prediction should keep synthetic data generation reproducible" - ); - assert!( - source.contains("generate_stock_data(80)"), + black_box(stock_prediction::DEMO_STOCK_DAYS) <= 80, "stock_prediction should keep its synthetic dataset bounded for interactive runs" ); assert!( - source.contains("StockPredictor::new(5, 8)"), - "stock_prediction should keep sequence length and hidden size bounded for interactive runs" + black_box(stock_prediction::DEMO_SEQUENCE_LENGTH) <= 5, + "stock_prediction should keep sequence length bounded for interactive runs" ); assert!( - source.contains(".with_config(demo_training_config())"), - "stock_prediction should apply the bounded demo config before training" + black_box(stock_prediction::DEMO_HIDDEN_SIZE) <= 8, + "stock_prediction should keep hidden size bounded for interactive runs" ); } #[test] -fn batch_processing_example_keeps_scalability_demo_bounded() { - let source = include_str!("../examples/batch_processing_example.rs"); - - for expected in [ - "trainer1.config.epochs = 5", - "trainer2.config.epochs = 5", - "trainer3.config.epochs = 5", - ] { - assert!( - source.contains(expected), - "batch benchmark trainer should keep a short epoch budget: {expected}" - ); - } +fn stock_prediction_demo_data_is_reproducible() { + let first = stock_prediction::demo_stock_closes(black_box(8)); + let second = stock_prediction::demo_stock_closes(black_box(8)); + assert_eq!(first, second, "seeded demo data should be reproducible"); assert!( - source.contains("trainer.config.epochs = 3"), - "scalability loop should keep a short epoch budget" + first.windows(2).any(|window| window[0] != window[1]), + "demo data should still vary across generated days" ); }