Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions examples/stock_prediction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,20 @@
#![allow(unused_comparisons)]

use ndarray::{arr2, Array2};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rust_lstm::loss::MSELoss;
use rust_lstm::models::lstm_network::LSTMNetwork;
use rust_lstm::optimizers::Adam;
use rust_lstm::training::LSTMTrainer;
use rust_lstm::training::{LSTMTrainer, TrainingConfig};

fn demo_training_config() -> TrainingConfig {
TrainingConfig {
epochs: 2,
print_every: 1,
log_lr_changes: false,
..TrainingConfig::default()
}
}

/// Stock data point with OHLCV (Open, High, Low, Close, Volume)
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -140,7 +150,8 @@ impl StockPredictor {
// Create trainer with Adam optimizer
let loss_function = MSELoss;
let optimizer = Adam::new(0.001);
let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer);
let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer)
.with_config(demo_training_config());

// Train the model
trainer.train(train_data, Some(val_data));
Expand Down Expand Up @@ -178,6 +189,7 @@ impl StockPredictor {

/// Generate synthetic stock data for demonstration
fn generate_stock_data(days: usize) -> Vec<StockData> {
let mut rng = StdRng::seed_from_u64(42);
let mut data = Vec::new();
let mut price = 100.0;
let volume_base = 1_000_000.0;
Expand All @@ -186,19 +198,19 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
// Random walk with trend and volatility
let trend = 0.001; // Slight upward trend
let volatility = 0.02;
let random_change = (rand::random::<f64>() - 0.5) * volatility;
let random_change = (rng.gen::<f64>() - 0.5) * volatility;

price *= 1.0 + trend + random_change;
price = price.max(1.0); // Prevent negative prices

// Generate OHLC based on closing price
let daily_volatility = 0.005;
let high = price * (1.0 + rand::random::<f64>() * daily_volatility);
let low = price * (1.0 - rand::random::<f64>() * daily_volatility);
let open = low + (high - low) * rand::random::<f64>();
let high = price * (1.0 + rng.gen::<f64>() * daily_volatility);
let low = price * (1.0 - rng.gen::<f64>() * daily_volatility);
let open = low + (high - low) * rng.gen::<f64>();

// Volume with some correlation to price movement
let volume_factor = 0.8 + 0.4 * rand::random::<f64>();
let volume_factor = 0.8 + 0.4 * rng.gen::<f64>();
let volume = volume_base * volume_factor;

data.push(StockData {
Expand All @@ -217,9 +229,10 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
fn main() {
println!("🏦 Stock Price Prediction with LSTM");
println!("=====================================\n");
println!("This bounded demo favors quick execution over prediction quality.\n");

// Generate synthetic stock data (in practice, you'd load real data)
let stock_data = generate_stock_data(500); // 500 days of data
let stock_data = generate_stock_data(80); // Bounded synthetic dataset for a quick demo
println!(
"📈 Generated {} days of synthetic stock data",
stock_data.len()
Expand All @@ -237,16 +250,16 @@ fn main() {
}

// Create and train predictor
let mut predictor = StockPredictor::new(20, 50); // 20-day sequences, 50 hidden units
let mut predictor = StockPredictor::new(5, 8); // 5-day sequences, 8 hidden units
predictor.train(&stock_data, 0.2); // 80% train, 20% validation

// Make predictions on recent data
println!("\n🔮 Making predictions...");
let recent_data = &stock_data[stock_data.len() - 30..]; // Last 30 days
let recent_data = &stock_data[stock_data.len() - 10..]; // Last 10 days

for i in 20..25 {
// Predict for days 21-25 of recent data
let input_data = &recent_data[i - 20..i];
for i in 5..10 {
// Predict for days 6-10 of recent data
let input_data = &recent_data[i - 5..i];
if let Some(predicted_price) = predictor.predict_next_price(input_data) {
let actual_price = recent_data[i].close;
let error = (predicted_price - actual_price).abs();
Expand Down
17 changes: 14 additions & 3 deletions examples/training_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@ use ndarray::{arr2, Array2};
use rust_lstm::loss::MSELoss;
use rust_lstm::models::lstm_network::LSTMNetwork;
use rust_lstm::optimizers::{Adam, SGD};
use rust_lstm::training::LSTMTrainer;
use rust_lstm::training::{LSTMTrainer, TrainingConfig};

fn demo_training_config() -> TrainingConfig {
TrainingConfig {
epochs: 5,
print_every: 1,
log_lr_changes: false,
..TrainingConfig::default()
}
}

/// Generate sine wave training data for sequence prediction
fn generate_sine_data(
Expand Down Expand Up @@ -87,7 +96,8 @@ fn main() {
// Training with SGD
println!("Training with SGD optimizer:");
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer_sgd = LSTMTrainer::new(network, MSELoss, SGD::new(0.01));
let mut trainer_sgd =
LSTMTrainer::new(network, MSELoss, SGD::new(0.01)).with_config(demo_training_config());

trainer_sgd.train(&train_data, Some(&val_data));

Expand All @@ -109,7 +119,8 @@ fn main() {
// Training with Adam
println!("Training with Adam optimizer:");
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer_adam = LSTMTrainer::new(network, MSELoss, Adam::new(0.001));
let mut trainer_adam =
LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(demo_training_config());

trainer_adam.train(&train_data, Some(&val_data));

Expand Down
67 changes: 67 additions & 0 deletions tests/example_training_bounds_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#[test]
fn training_example_uses_bounded_demo_epochs() {
let source = include_str!("../examples/training_example.rs");

assert!(
source.contains("epochs: 5"),
"training_example should keep interactive demo training bounded"
);
assert!(
source.contains("fn demo_training_config() -> TrainingConfig"),
"training_example should centralize its bounded demo training config"
);
assert!(
source
.matches(".with_config(demo_training_config())")
.count()
>= 2,
"both optimizer demonstrations should apply the bounded demo config"
);
}

#[test]
fn stock_prediction_example_uses_bounded_demo_epochs() {
let source = include_str!("../examples/stock_prediction.rs");

assert!(
source.contains("epochs: 2"),
"stock_prediction should avoid the default 100-epoch training config"
);
assert!(
source.contains("StdRng::seed_from_u64(42)"),
"stock_prediction should keep synthetic data generation reproducible"
);
assert!(
source.contains("generate_stock_data(80)"),
"stock_prediction should keep its synthetic dataset bounded for interactive runs"
);
assert!(
source.contains("StockPredictor::new(5, 8)"),
"stock_prediction should keep sequence length and hidden size bounded for interactive runs"
);
assert!(
source.contains(".with_config(demo_training_config())"),
"stock_prediction should apply the bounded demo config before training"
);
}

#[test]
fn batch_processing_example_keeps_scalability_demo_bounded() {
let source = include_str!("../examples/batch_processing_example.rs");

for expected in [
"trainer1.config.epochs = 5",
"trainer2.config.epochs = 5",
"trainer3.config.epochs = 5",
] {
assert!(
source.contains(expected),
"batch benchmark trainer should keep a short epoch budget: {expected}"
);
}

assert!(
source.contains("trainer.config.epochs = 3"),
"scalability loop should keep a short epoch budget"
);
}
Loading