Skip to content

Commit 0395bb0

Browse files
committed
test: bound example training runs
1 parent d6a9910 commit 0395bb0

3 files changed

Lines changed: 153 additions & 21 deletions

File tree

examples/stock_prediction.rs

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,31 @@
77
#![allow(unused_comparisons)]
88

99
use ndarray::{arr2, Array2};
10+
use rand::{rngs::StdRng, Rng, SeedableRng};
1011
use rust_lstm::loss::MSELoss;
1112
use rust_lstm::models::lstm_network::LSTMNetwork;
1213
use rust_lstm::optimizers::Adam;
13-
use rust_lstm::training::LSTMTrainer;
14+
use rust_lstm::training::{LSTMTrainer, TrainingConfig};
15+
16+
pub(crate) const DEMO_EPOCHS: usize = 2;
17+
pub(crate) const DEMO_PRINT_EVERY: usize = 1;
18+
pub(crate) const DEMO_STOCK_DAYS: usize = 80;
19+
pub(crate) const DEMO_SEQUENCE_LENGTH: usize = 5;
20+
pub(crate) const DEMO_HIDDEN_SIZE: usize = 8;
21+
pub(crate) const DEMO_RNG_SEED: u64 = 42;
22+
23+
pub(crate) fn demo_training_config() -> TrainingConfig {
24+
TrainingConfig {
25+
epochs: DEMO_EPOCHS,
26+
print_every: DEMO_PRINT_EVERY,
27+
log_lr_changes: false,
28+
..TrainingConfig::default()
29+
}
30+
}
31+
32+
pub(crate) fn demo_stock_trainer(network: LSTMNetwork) -> LSTMTrainer<MSELoss, Adam> {
33+
LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(demo_training_config())
34+
}
1435

1536
/// Stock data point with OHLCV (Open, High, Low, Close, Volume)
1637
#[derive(Debug, Clone)]
@@ -137,10 +158,7 @@ impl StockPredictor {
137158
val_data.len()
138159
);
139160

140-
// Create trainer with Adam optimizer
141-
let loss_function = MSELoss;
142-
let optimizer = Adam::new(0.001);
143-
let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer);
161+
let mut trainer = demo_stock_trainer(self.network.clone());
144162

145163
// Train the model
146164
trainer.train(train_data, Some(val_data));
@@ -178,6 +196,7 @@ impl StockPredictor {
178196

179197
/// Generate synthetic stock data for demonstration
180198
fn generate_stock_data(days: usize) -> Vec<StockData> {
199+
let mut rng = StdRng::seed_from_u64(DEMO_RNG_SEED);
181200
let mut data = Vec::new();
182201
let mut price = 100.0;
183202
let volume_base = 1_000_000.0;
@@ -186,19 +205,19 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
186205
// Random walk with trend and volatility
187206
let trend = 0.001; // Slight upward trend
188207
let volatility = 0.02;
189-
let random_change = (rand::random::<f64>() - 0.5) * volatility;
208+
let random_change = (rng.gen::<f64>() - 0.5) * volatility;
190209

191210
price *= 1.0 + trend + random_change;
192211
price = price.max(1.0); // Prevent negative prices
193212

194213
// Generate OHLC based on closing price
195214
let daily_volatility = 0.005;
196-
let high = price * (1.0 + rand::random::<f64>() * daily_volatility);
197-
let low = price * (1.0 - rand::random::<f64>() * daily_volatility);
198-
let open = low + (high - low) * rand::random::<f64>();
215+
let high = price * (1.0 + rng.gen::<f64>() * daily_volatility);
216+
let low = price * (1.0 - rng.gen::<f64>() * daily_volatility);
217+
let open = low + (high - low) * rng.gen::<f64>();
199218

200219
// Volume with some correlation to price movement
201-
let volume_factor = 0.8 + 0.4 * rand::random::<f64>();
220+
let volume_factor = 0.8 + 0.4 * rng.gen::<f64>();
202221
let volume = volume_base * volume_factor;
203222

204223
data.push(StockData {
@@ -214,12 +233,21 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
214233
data
215234
}
216235

236+
#[allow(dead_code)]
237+
pub(crate) fn demo_stock_closes(days: usize) -> Vec<f64> {
238+
generate_stock_data(days)
239+
.into_iter()
240+
.map(|stock| stock.close)
241+
.collect()
242+
}
243+
217244
fn main() {
218245
println!("🏦 Stock Price Prediction with LSTM");
219246
println!("=====================================\n");
247+
println!("This bounded demo favors quick execution over prediction quality.\n");
220248

221249
// Generate synthetic stock data (in practice, you'd load real data)
222-
let stock_data = generate_stock_data(500); // 500 days of data
250+
let stock_data = generate_stock_data(DEMO_STOCK_DAYS); // Bounded synthetic dataset for a quick demo
223251
println!(
224252
"📈 Generated {} days of synthetic stock data",
225253
stock_data.len()
@@ -237,16 +265,16 @@ fn main() {
237265
}
238266

239267
// Create and train predictor
240-
let mut predictor = StockPredictor::new(20, 50); // 20-day sequences, 50 hidden units
268+
let mut predictor = StockPredictor::new(DEMO_SEQUENCE_LENGTH, DEMO_HIDDEN_SIZE);
241269
predictor.train(&stock_data, 0.2); // 80% train, 20% validation
242270

243271
// Make predictions on recent data
244272
println!("\n🔮 Making predictions...");
245-
let recent_data = &stock_data[stock_data.len() - 30..]; // Last 30 days
273+
let recent_window = DEMO_SEQUENCE_LENGTH * 2;
274+
let recent_data = &stock_data[stock_data.len() - recent_window..];
246275

247-
for i in 20..25 {
248-
// Predict for days 21-25 of recent data
249-
let input_data = &recent_data[i - 20..i];
276+
for i in DEMO_SEQUENCE_LENGTH..recent_window {
277+
let input_data = &recent_data[i - DEMO_SEQUENCE_LENGTH..i];
250278
if let Some(predicted_price) = predictor.predict_next_price(input_data) {
251279
let actual_price = recent_data[i].close;
252280
let error = (predicted_price - actual_price).abs();

examples/training_example.rs

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,45 @@ use ndarray::{arr2, Array2};
1010
use rust_lstm::loss::MSELoss;
1111
use rust_lstm::models::lstm_network::LSTMNetwork;
1212
use rust_lstm::optimizers::{Adam, SGD};
13-
use rust_lstm::training::LSTMTrainer;
13+
use rust_lstm::training::{LSTMTrainer, TrainingConfig};
14+
15+
pub(crate) const DEMO_EPOCHS: usize = 5;
16+
pub(crate) const DEMO_PRINT_EVERY: usize = 1;
17+
18+
pub(crate) fn demo_training_config() -> TrainingConfig {
19+
TrainingConfig {
20+
epochs: DEMO_EPOCHS,
21+
print_every: DEMO_PRINT_EVERY,
22+
log_lr_changes: false,
23+
..TrainingConfig::default()
24+
}
25+
}
26+
27+
pub(crate) fn demo_sgd_trainer(
28+
input_size: usize,
29+
hidden_size: usize,
30+
num_layers: usize,
31+
) -> LSTMTrainer<MSELoss, SGD> {
32+
LSTMTrainer::new(
33+
LSTMNetwork::new(input_size, hidden_size, num_layers),
34+
MSELoss,
35+
SGD::new(0.01),
36+
)
37+
.with_config(demo_training_config())
38+
}
39+
40+
pub(crate) fn demo_adam_trainer(
41+
input_size: usize,
42+
hidden_size: usize,
43+
num_layers: usize,
44+
) -> LSTMTrainer<MSELoss, Adam> {
45+
LSTMTrainer::new(
46+
LSTMNetwork::new(input_size, hidden_size, num_layers),
47+
MSELoss,
48+
Adam::new(0.001),
49+
)
50+
.with_config(demo_training_config())
51+
}
1452

1553
/// Generate sine wave training data for sequence prediction
1654
fn generate_sine_data(
@@ -86,8 +124,7 @@ fn main() {
86124

87125
// Training with SGD
88126
println!("Training with SGD optimizer:");
89-
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
90-
let mut trainer_sgd = LSTMTrainer::new(network, MSELoss, SGD::new(0.01));
127+
let mut trainer_sgd = demo_sgd_trainer(input_size, hidden_size, num_layers);
91128

92129
trainer_sgd.train(&train_data, Some(&val_data));
93130

@@ -108,8 +145,7 @@ fn main() {
108145

109146
// Training with Adam
110147
println!("Training with Adam optimizer:");
111-
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
112-
let mut trainer_adam = LSTMTrainer::new(network, MSELoss, Adam::new(0.001));
148+
let mut trainer_adam = demo_adam_trainer(input_size, hidden_size, num_layers);
113149

114150
trainer_adam.train(&train_data, Some(&val_data));
115151

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#[allow(dead_code)]
2+
#[path = "../examples/training_example.rs"]
3+
mod training_example;
4+
5+
#[allow(dead_code)]
6+
#[path = "../examples/stock_prediction.rs"]
7+
mod stock_prediction;
8+
9+
use std::hint::black_box;
10+
11+
#[test]
12+
fn training_example_applies_bounded_config_to_both_demo_trainers() {
13+
let sgd_trainer = training_example::demo_sgd_trainer(1, 10, 1);
14+
let adam_trainer = training_example::demo_adam_trainer(1, 10, 1);
15+
16+
for epochs in [sgd_trainer.config.epochs, adam_trainer.config.epochs] {
17+
assert!(
18+
black_box(epochs) <= 5,
19+
"training_example should keep both optimizer demos bounded"
20+
);
21+
}
22+
23+
for print_every in [
24+
sgd_trainer.config.print_every,
25+
adam_trainer.config.print_every,
26+
] {
27+
assert!(
28+
black_box(print_every) <= black_box(training_example::DEMO_EPOCHS),
29+
"training_example should report progress without exceeding the epoch budget"
30+
);
31+
}
32+
}
33+
34+
#[test]
35+
fn stock_prediction_applies_bounded_config_to_demo_trainer() {
36+
let network =
37+
rust_lstm::models::lstm_network::LSTMNetwork::new(5, stock_prediction::DEMO_HIDDEN_SIZE, 2);
38+
let trainer = stock_prediction::demo_stock_trainer(network);
39+
40+
assert!(
41+
black_box(trainer.config.epochs) <= 2,
42+
"stock_prediction should avoid the default 100-epoch training config"
43+
);
44+
assert!(
45+
black_box(stock_prediction::DEMO_STOCK_DAYS) <= 80,
46+
"stock_prediction should keep its synthetic dataset bounded for interactive runs"
47+
);
48+
assert!(
49+
black_box(stock_prediction::DEMO_SEQUENCE_LENGTH) <= 5,
50+
"stock_prediction should keep sequence length bounded for interactive runs"
51+
);
52+
assert!(
53+
black_box(stock_prediction::DEMO_HIDDEN_SIZE) <= 8,
54+
"stock_prediction should keep hidden size bounded for interactive runs"
55+
);
56+
}
57+
58+
#[test]
59+
fn stock_prediction_demo_data_is_reproducible() {
60+
let first = stock_prediction::demo_stock_closes(black_box(8));
61+
let second = stock_prediction::demo_stock_closes(black_box(8));
62+
63+
assert_eq!(first, second, "seeded demo data should be reproducible");
64+
assert!(
65+
first.windows(2).any(|window| window[0] != window[1]),
66+
"demo data should still vary across generated days"
67+
);
68+
}

0 commit comments

Comments
 (0)