Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions examples/stock_prediction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,26 @@ use rust_lstm::models::lstm_network::LSTMNetwork;
use rust_lstm::optimizers::Adam;
use rust_lstm::training::{LSTMTrainer, TrainingConfig};

fn demo_training_config() -> TrainingConfig {
pub(crate) const DEMO_EPOCHS: usize = 2;
pub(crate) const DEMO_PRINT_EVERY: usize = 1;
pub(crate) const DEMO_STOCK_DAYS: usize = 80;
pub(crate) const DEMO_SEQUENCE_LENGTH: usize = 5;
pub(crate) const DEMO_HIDDEN_SIZE: usize = 8;
pub(crate) const DEMO_RNG_SEED: u64 = 42;

pub(crate) fn demo_training_config() -> TrainingConfig {
TrainingConfig {
epochs: 2,
print_every: 1,
epochs: DEMO_EPOCHS,
print_every: DEMO_PRINT_EVERY,
log_lr_changes: false,
..TrainingConfig::default()
}
}

pub(crate) fn demo_stock_trainer(network: LSTMNetwork) -> LSTMTrainer<MSELoss, Adam> {
LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(demo_training_config())
}

/// Stock data point with OHLCV (Open, High, Low, Close, Volume)
#[derive(Debug, Clone)]
#[allow(dead_code)]
Expand Down Expand Up @@ -147,11 +158,7 @@ impl StockPredictor {
val_data.len()
);

// Create trainer with Adam optimizer
let loss_function = MSELoss;
let optimizer = Adam::new(0.001);
let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer)
.with_config(demo_training_config());
let mut trainer = demo_stock_trainer(self.network.clone());

// Train the model
trainer.train(train_data, Some(val_data));
Expand Down Expand Up @@ -189,7 +196,7 @@ impl StockPredictor {

/// Generate synthetic stock data for demonstration
fn generate_stock_data(days: usize) -> Vec<StockData> {
let mut rng = StdRng::seed_from_u64(42);
let mut rng = StdRng::seed_from_u64(DEMO_RNG_SEED);
let mut data = Vec::new();
let mut price = 100.0;
let volume_base = 1_000_000.0;
Expand Down Expand Up @@ -226,13 +233,21 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
data
}

#[allow(dead_code)]
pub(crate) fn demo_stock_closes(days: usize) -> Vec<f64> {
generate_stock_data(days)
.into_iter()
.map(|stock| stock.close)
.collect()
}

fn main() {
println!("🏦 Stock Price Prediction with LSTM");
println!("=====================================\n");
println!("This bounded demo favors quick execution over prediction quality.\n");

// Generate synthetic stock data (in practice, you'd load real data)
let stock_data = generate_stock_data(80); // Bounded synthetic dataset for a quick demo
let stock_data = generate_stock_data(DEMO_STOCK_DAYS); // Bounded synthetic dataset for a quick demo
println!(
"📈 Generated {} days of synthetic stock data",
stock_data.len()
Expand All @@ -250,16 +265,16 @@ fn main() {
}

// Create and train predictor
let mut predictor = StockPredictor::new(5, 8); // 5-day sequences, 8 hidden units
let mut predictor = StockPredictor::new(DEMO_SEQUENCE_LENGTH, DEMO_HIDDEN_SIZE);
predictor.train(&stock_data, 0.2); // 80% train, 20% validation

// Make predictions on recent data
println!("\n🔮 Making predictions...");
let recent_data = &stock_data[stock_data.len() - 10..]; // Last 10 days
let recent_window = DEMO_SEQUENCE_LENGTH * 2;
let recent_data = &stock_data[stock_data.len() - recent_window..];

for i in 5..10 {
// Predict for days 6-10 of recent data
let input_data = &recent_data[i - 5..i];
for i in DEMO_SEQUENCE_LENGTH..recent_window {
let input_data = &recent_data[i - DEMO_SEQUENCE_LENGTH..i];
if let Some(predicted_price) = predictor.predict_next_price(input_data) {
let actual_price = recent_data[i].close;
let error = (predicted_price - actual_price).abs();
Expand Down
43 changes: 34 additions & 9 deletions examples/training_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,44 @@ use rust_lstm::models::lstm_network::LSTMNetwork;
use rust_lstm::optimizers::{Adam, SGD};
use rust_lstm::training::{LSTMTrainer, TrainingConfig};

fn demo_training_config() -> TrainingConfig {
pub(crate) const DEMO_EPOCHS: usize = 5;
pub(crate) const DEMO_PRINT_EVERY: usize = 1;

pub(crate) fn demo_training_config() -> TrainingConfig {
TrainingConfig {
epochs: 5,
print_every: 1,
epochs: DEMO_EPOCHS,
print_every: DEMO_PRINT_EVERY,
log_lr_changes: false,
..TrainingConfig::default()
}
}

pub(crate) fn demo_sgd_trainer(
input_size: usize,
hidden_size: usize,
num_layers: usize,
) -> LSTMTrainer<MSELoss, SGD> {
LSTMTrainer::new(
LSTMNetwork::new(input_size, hidden_size, num_layers),
MSELoss,
SGD::new(0.01),
)
.with_config(demo_training_config())
}

pub(crate) fn demo_adam_trainer(
input_size: usize,
hidden_size: usize,
num_layers: usize,
) -> LSTMTrainer<MSELoss, Adam> {
LSTMTrainer::new(
LSTMNetwork::new(input_size, hidden_size, num_layers),
MSELoss,
Adam::new(0.001),
)
.with_config(demo_training_config())
}

/// Generate sine wave training data for sequence prediction
fn generate_sine_data(
num_sequences: usize,
Expand Down Expand Up @@ -95,9 +124,7 @@ fn main() {

// Training with SGD
println!("Training with SGD optimizer:");
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer_sgd =
LSTMTrainer::new(network, MSELoss, SGD::new(0.01)).with_config(demo_training_config());
let mut trainer_sgd = demo_sgd_trainer(input_size, hidden_size, num_layers);

trainer_sgd.train(&train_data, Some(&val_data));

Expand All @@ -118,9 +145,7 @@ fn main() {

// Training with Adam
println!("Training with Adam optimizer:");
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
let mut trainer_adam =
LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(demo_training_config());
let mut trainer_adam = demo_adam_trainer(input_size, hidden_size, num_layers);

trainer_adam.train(&train_data, Some(&val_data));

Expand Down
89 changes: 45 additions & 44 deletions tests/example_training_bounds_test.rs
Original file line number Diff line number Diff line change
@@ -1,67 +1,68 @@
#[allow(dead_code)]
#[path = "../examples/training_example.rs"]
mod training_example;

#[allow(dead_code)]
#[path = "../examples/stock_prediction.rs"]
mod stock_prediction;

use std::hint::black_box;

#[test]
fn training_example_uses_bounded_demo_epochs() {
let source = include_str!("../examples/training_example.rs");
fn training_example_applies_bounded_config_to_both_demo_trainers() {
let sgd_trainer = training_example::demo_sgd_trainer(1, 10, 1);
let adam_trainer = training_example::demo_adam_trainer(1, 10, 1);

assert!(
source.contains("epochs: 5"),
"training_example should keep interactive demo training bounded"
);
assert!(
source.contains("fn demo_training_config() -> TrainingConfig"),
"training_example should centralize its bounded demo training config"
);
assert!(
source
.matches(".with_config(demo_training_config())")
.count()
>= 2,
"both optimizer demonstrations should apply the bounded demo config"
);
for epochs in [sgd_trainer.config.epochs, adam_trainer.config.epochs] {
assert!(
black_box(epochs) <= 5,
"training_example should keep both optimizer demos bounded"
);
}

for print_every in [
sgd_trainer.config.print_every,
adam_trainer.config.print_every,
] {
assert!(
black_box(print_every) <= black_box(training_example::DEMO_EPOCHS),
"training_example should report progress without exceeding the epoch budget"
);
}
}

#[test]
fn stock_prediction_example_uses_bounded_demo_epochs() {
let source = include_str!("../examples/stock_prediction.rs");
fn stock_prediction_applies_bounded_config_to_demo_trainer() {
let network =
rust_lstm::models::lstm_network::LSTMNetwork::new(5, stock_prediction::DEMO_HIDDEN_SIZE, 2);
let trainer = stock_prediction::demo_stock_trainer(network);

assert!(
source.contains("epochs: 2"),
black_box(trainer.config.epochs) <= 2,
"stock_prediction should avoid the default 100-epoch training config"
);
assert!(
source.contains("StdRng::seed_from_u64(42)"),
"stock_prediction should keep synthetic data generation reproducible"
);
assert!(
source.contains("generate_stock_data(80)"),
black_box(stock_prediction::DEMO_STOCK_DAYS) <= 80,
"stock_prediction should keep its synthetic dataset bounded for interactive runs"
);
assert!(
source.contains("StockPredictor::new(5, 8)"),
"stock_prediction should keep sequence length and hidden size bounded for interactive runs"
black_box(stock_prediction::DEMO_SEQUENCE_LENGTH) <= 5,
"stock_prediction should keep sequence length bounded for interactive runs"
);
assert!(
source.contains(".with_config(demo_training_config())"),
"stock_prediction should apply the bounded demo config before training"
black_box(stock_prediction::DEMO_HIDDEN_SIZE) <= 8,
"stock_prediction should keep hidden size bounded for interactive runs"
);
}

#[test]
fn batch_processing_example_keeps_scalability_demo_bounded() {
let source = include_str!("../examples/batch_processing_example.rs");

for expected in [
"trainer1.config.epochs = 5",
"trainer2.config.epochs = 5",
"trainer3.config.epochs = 5",
] {
assert!(
source.contains(expected),
"batch benchmark trainer should keep a short epoch budget: {expected}"
);
}
fn stock_prediction_demo_data_is_reproducible() {
let first = stock_prediction::demo_stock_closes(black_box(8));
let second = stock_prediction::demo_stock_closes(black_box(8));

assert_eq!(first, second, "seeded demo data should be reproducible");
assert!(
source.contains("trainer.config.epochs = 3"),
"scalability loop should keep a short epoch budget"
first.windows(2).any(|window| window[0] != window[1]),
"demo data should still vary across generated days"
);
}
Loading