diff --git a/examples/stock_prediction.rs b/examples/stock_prediction.rs index 52eb7a7..abe6c81 100644 --- a/examples/stock_prediction.rs +++ b/examples/stock_prediction.rs @@ -7,10 +7,20 @@ #![allow(unused_comparisons)] use ndarray::{arr2, Array2}; +use rand::{rngs::StdRng, Rng, SeedableRng}; use rust_lstm::loss::MSELoss; use rust_lstm::models::lstm_network::LSTMNetwork; use rust_lstm::optimizers::Adam; -use rust_lstm::training::LSTMTrainer; +use rust_lstm::training::{LSTMTrainer, TrainingConfig}; + +fn demo_training_config() -> TrainingConfig { + TrainingConfig { + epochs: 2, + print_every: 1, + log_lr_changes: false, + ..TrainingConfig::default() + } +} /// Stock data point with OHLCV (Open, High, Low, Close, Volume) #[derive(Debug, Clone)] @@ -140,7 +150,8 @@ impl StockPredictor { // Create trainer with Adam optimizer let loss_function = MSELoss; let optimizer = Adam::new(0.001); - let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer); + let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer) + .with_config(demo_training_config()); // Train the model trainer.train(train_data, Some(val_data)); @@ -178,6 +189,7 @@ impl StockPredictor { /// Generate synthetic stock data for demonstration fn generate_stock_data(days: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(42); let mut data = Vec::new(); let mut price = 100.0; let volume_base = 1_000_000.0; @@ -186,19 +198,19 @@ fn generate_stock_data(days: usize) -> Vec { // Random walk with trend and volatility let trend = 0.001; // Slight upward trend let volatility = 0.02; - let random_change = (rand::random::() - 0.5) * volatility; + let random_change = (rng.gen::() - 0.5) * volatility; price *= 1.0 + trend + random_change; price = price.max(1.0); // Prevent negative prices // Generate OHLC based on closing price let daily_volatility = 0.005; - let high = price * (1.0 + rand::random::() * daily_volatility); - let low = price * (1.0 - rand::random::() * daily_volatility); - let open = low + (high - low) * rand::random::(); + let high = price * (1.0 + rng.gen::() * daily_volatility); + let low = price * (1.0 - rng.gen::() * daily_volatility); + let open = low + (high - low) * rng.gen::(); // Volume with some correlation to price movement - let volume_factor = 0.8 + 0.4 * rand::random::(); + let volume_factor = 0.8 + 0.4 * rng.gen::(); let volume = volume_base * volume_factor; data.push(StockData { @@ -217,9 +229,10 @@ fn generate_stock_data(days: usize) -> Vec { fn main() { println!("šŸ¦ Stock Price Prediction with LSTM"); println!("=====================================\n"); + println!("This bounded demo favors quick execution over prediction quality.\n"); // Generate synthetic stock data (in practice, you'd load real data) - let stock_data = generate_stock_data(500); // 500 days of data + let stock_data = generate_stock_data(80); // Bounded synthetic dataset for a quick demo println!( "šŸ“ˆ Generated {} days of synthetic stock data", stock_data.len() @@ -237,16 +250,16 @@ fn main() { } // Create and train predictor - let mut predictor = StockPredictor::new(20, 50); // 20-day sequences, 50 hidden units + let mut predictor = StockPredictor::new(5, 8); // 5-day sequences, 8 hidden units predictor.train(&stock_data, 0.2); // 80% train, 20% validation // Make predictions on recent data println!("\nšŸ”® Making predictions..."); - let recent_data = &stock_data[stock_data.len() - 30..]; // Last 30 days + let recent_data = &stock_data[stock_data.len() - 10..]; // Last 10 days - for i in 20..25 { - // Predict for days 21-25 of recent data - let input_data = &recent_data[i - 20..i]; + for i in 5..10 { + // Predict for days 6-10 of recent data + let input_data = &recent_data[i - 5..i]; if let Some(predicted_price) = predictor.predict_next_price(input_data) { let actual_price = recent_data[i].close; let error = (predicted_price - actual_price).abs(); diff --git a/examples/training_example.rs b/examples/training_example.rs index 7f460a9..b47b256 100644 --- a/examples/training_example.rs +++ b/examples/training_example.rs @@ -10,7 +10,16 @@ use ndarray::{arr2, Array2}; use rust_lstm::loss::MSELoss; use rust_lstm::models::lstm_network::LSTMNetwork; use rust_lstm::optimizers::{Adam, SGD}; -use rust_lstm::training::LSTMTrainer; +use rust_lstm::training::{LSTMTrainer, TrainingConfig}; + +fn demo_training_config() -> TrainingConfig { + TrainingConfig { + epochs: 5, + print_every: 1, + log_lr_changes: false, + ..TrainingConfig::default() + } +} /// Generate sine wave training data for sequence prediction fn generate_sine_data( @@ -87,7 +96,8 @@ fn main() { // Training with SGD println!("Training with SGD optimizer:"); let network = LSTMNetwork::new(input_size, hidden_size, num_layers); - let mut trainer_sgd = LSTMTrainer::new(network, MSELoss, SGD::new(0.01)); + let mut trainer_sgd = + LSTMTrainer::new(network, MSELoss, SGD::new(0.01)).with_config(demo_training_config()); trainer_sgd.train(&train_data, Some(&val_data)); @@ -109,7 +119,8 @@ fn main() { // Training with Adam println!("Training with Adam optimizer:"); let network = LSTMNetwork::new(input_size, hidden_size, num_layers); - let mut trainer_adam = LSTMTrainer::new(network, MSELoss, Adam::new(0.001)); + let mut trainer_adam = + LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(demo_training_config()); trainer_adam.train(&train_data, Some(&val_data)); diff --git a/tests/example_training_bounds_test.rs b/tests/example_training_bounds_test.rs new file mode 100644 index 0000000..000971c --- /dev/null +++ b/tests/example_training_bounds_test.rs @@ -0,0 +1,67 @@ +#[test] +fn training_example_uses_bounded_demo_epochs() { + let source = include_str!("../examples/training_example.rs"); + + assert!( + source.contains("epochs: 5"), + "training_example should keep interactive demo training bounded" + ); + assert!( + source.contains("fn demo_training_config() -> TrainingConfig"), + "training_example should centralize its bounded demo training config" + ); + assert!( + source + .matches(".with_config(demo_training_config())") + .count() + >= 2, + "both optimizer demonstrations should apply the bounded demo config" + ); +} + +#[test] +fn stock_prediction_example_uses_bounded_demo_epochs() { + let source = include_str!("../examples/stock_prediction.rs"); + + assert!( + source.contains("epochs: 2"), + "stock_prediction should avoid the default 100-epoch training config" + ); + assert!( + source.contains("StdRng::seed_from_u64(42)"), + "stock_prediction should keep synthetic data generation reproducible" + ); + assert!( + source.contains("generate_stock_data(80)"), + "stock_prediction should keep its synthetic dataset bounded for interactive runs" + ); + assert!( + source.contains("StockPredictor::new(5, 8)"), + "stock_prediction should keep sequence length and hidden size bounded for interactive runs" + ); + assert!( + source.contains(".with_config(demo_training_config())"), + "stock_prediction should apply the bounded demo config before training" + ); +} + +#[test] +fn batch_processing_example_keeps_scalability_demo_bounded() { + let source = include_str!("../examples/batch_processing_example.rs"); + + for expected in [ + "trainer1.config.epochs = 5", + "trainer2.config.epochs = 5", + "trainer3.config.epochs = 5", + ] { + assert!( + source.contains(expected), + "batch benchmark trainer should keep a short epoch budget: {expected}" + ); + } + + assert!( + source.contains("trainer.config.epochs = 3"), + "scalability loop should keep a short epoch budget" + ); +}