Skip to content

Commit 941f61a

Browse files
authored
test: bound example training runs (#20)
1 parent d6a9910 commit 941f61a

3 files changed

Lines changed: 107 additions & 16 deletions

File tree

examples/stock_prediction.rs

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,20 @@
77
#![allow(unused_comparisons)]
88

99
use ndarray::{arr2, Array2};
10+
use rand::{rngs::StdRng, Rng, SeedableRng};
1011
use rust_lstm::loss::MSELoss;
1112
use rust_lstm::models::lstm_network::LSTMNetwork;
1213
use rust_lstm::optimizers::Adam;
13-
use rust_lstm::training::LSTMTrainer;
14+
use rust_lstm::training::{LSTMTrainer, TrainingConfig};
15+
16+
fn demo_training_config() -> TrainingConfig {
17+
TrainingConfig {
18+
epochs: 2,
19+
print_every: 1,
20+
log_lr_changes: false,
21+
..TrainingConfig::default()
22+
}
23+
}
1424

1525
/// Stock data point with OHLCV (Open, High, Low, Close, Volume)
1626
#[derive(Debug, Clone)]
@@ -140,7 +150,8 @@ impl StockPredictor {
140150
// Create trainer with Adam optimizer
141151
let loss_function = MSELoss;
142152
let optimizer = Adam::new(0.001);
143-
let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer);
153+
let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer)
154+
.with_config(demo_training_config());
144155

145156
// Train the model
146157
trainer.train(train_data, Some(val_data));
@@ -178,6 +189,7 @@ impl StockPredictor {
178189

179190
/// Generate synthetic stock data for demonstration
180191
fn generate_stock_data(days: usize) -> Vec<StockData> {
192+
let mut rng = StdRng::seed_from_u64(42);
181193
let mut data = Vec::new();
182194
let mut price = 100.0;
183195
let volume_base = 1_000_000.0;
@@ -186,19 +198,19 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
186198
// Random walk with trend and volatility
187199
let trend = 0.001; // Slight upward trend
188200
let volatility = 0.02;
189-
let random_change = (rand::random::<f64>() - 0.5) * volatility;
201+
let random_change = (rng.gen::<f64>() - 0.5) * volatility;
190202

191203
price *= 1.0 + trend + random_change;
192204
price = price.max(1.0); // Prevent negative prices
193205

194206
// Generate OHLC based on closing price
195207
let daily_volatility = 0.005;
196-
let high = price * (1.0 + rand::random::<f64>() * daily_volatility);
197-
let low = price * (1.0 - rand::random::<f64>() * daily_volatility);
198-
let open = low + (high - low) * rand::random::<f64>();
208+
let high = price * (1.0 + rng.gen::<f64>() * daily_volatility);
209+
let low = price * (1.0 - rng.gen::<f64>() * daily_volatility);
210+
let open = low + (high - low) * rng.gen::<f64>();
199211

200212
// Volume with some correlation to price movement
201-
let volume_factor = 0.8 + 0.4 * rand::random::<f64>();
213+
let volume_factor = 0.8 + 0.4 * rng.gen::<f64>();
202214
let volume = volume_base * volume_factor;
203215

204216
data.push(StockData {
@@ -217,9 +229,10 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
217229
fn main() {
218230
println!("🏦 Stock Price Prediction with LSTM");
219231
println!("=====================================\n");
232+
println!("This bounded demo favors quick execution over prediction quality.\n");
220233

221234
// Generate synthetic stock data (in practice, you'd load real data)
222-
let stock_data = generate_stock_data(500); // 500 days of data
235+
let stock_data = generate_stock_data(80); // Bounded synthetic dataset for a quick demo
223236
println!(
224237
"📈 Generated {} days of synthetic stock data",
225238
stock_data.len()
@@ -237,16 +250,16 @@ fn main() {
237250
}
238251

239252
// Create and train predictor
240-
let mut predictor = StockPredictor::new(20, 50); // 20-day sequences, 50 hidden units
253+
let mut predictor = StockPredictor::new(5, 8); // 5-day sequences, 8 hidden units
241254
predictor.train(&stock_data, 0.2); // 80% train, 20% validation
242255

243256
// Make predictions on recent data
244257
println!("\n🔮 Making predictions...");
245-
let recent_data = &stock_data[stock_data.len() - 30..]; // Last 30 days
258+
let recent_data = &stock_data[stock_data.len() - 10..]; // Last 10 days
246259

247-
for i in 20..25 {
248-
// Predict for days 21-25 of recent data
249-
let input_data = &recent_data[i - 20..i];
260+
for i in 5..10 {
261+
// Predict for days 6-10 of recent data
262+
let input_data = &recent_data[i - 5..i];
250263
if let Some(predicted_price) = predictor.predict_next_price(input_data) {
251264
let actual_price = recent_data[i].close;
252265
let error = (predicted_price - actual_price).abs();

examples/training_example.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,16 @@ use ndarray::{arr2, Array2};
1010
use rust_lstm::loss::MSELoss;
1111
use rust_lstm::models::lstm_network::LSTMNetwork;
1212
use rust_lstm::optimizers::{Adam, SGD};
13-
use rust_lstm::training::LSTMTrainer;
13+
use rust_lstm::training::{LSTMTrainer, TrainingConfig};
14+
15+
fn demo_training_config() -> TrainingConfig {
16+
TrainingConfig {
17+
epochs: 5,
18+
print_every: 1,
19+
log_lr_changes: false,
20+
..TrainingConfig::default()
21+
}
22+
}
1423

1524
/// Generate sine wave training data for sequence prediction
1625
fn generate_sine_data(
@@ -87,7 +96,8 @@ fn main() {
8796
// Training with SGD
8897
println!("Training with SGD optimizer:");
8998
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
90-
let mut trainer_sgd = LSTMTrainer::new(network, MSELoss, SGD::new(0.01));
99+
let mut trainer_sgd =
100+
LSTMTrainer::new(network, MSELoss, SGD::new(0.01)).with_config(demo_training_config());
91101

92102
trainer_sgd.train(&train_data, Some(&val_data));
93103

@@ -109,7 +119,8 @@ fn main() {
109119
// Training with Adam
110120
println!("Training with Adam optimizer:");
111121
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
112-
let mut trainer_adam = LSTMTrainer::new(network, MSELoss, Adam::new(0.001));
122+
let mut trainer_adam =
123+
LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(demo_training_config());
113124

114125
trainer_adam.train(&train_data, Some(&val_data));
115126

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#[test]
2+
fn training_example_uses_bounded_demo_epochs() {
3+
let source = include_str!("../examples/training_example.rs");
4+
5+
assert!(
6+
source.contains("epochs: 5"),
7+
"training_example should keep interactive demo training bounded"
8+
);
9+
assert!(
10+
source.contains("fn demo_training_config() -> TrainingConfig"),
11+
"training_example should centralize its bounded demo training config"
12+
);
13+
assert!(
14+
source
15+
.matches(".with_config(demo_training_config())")
16+
.count()
17+
>= 2,
18+
"both optimizer demonstrations should apply the bounded demo config"
19+
);
20+
}
21+
22+
#[test]
23+
fn stock_prediction_example_uses_bounded_demo_epochs() {
24+
let source = include_str!("../examples/stock_prediction.rs");
25+
26+
assert!(
27+
source.contains("epochs: 2"),
28+
"stock_prediction should avoid the default 100-epoch training config"
29+
);
30+
assert!(
31+
source.contains("StdRng::seed_from_u64(42)"),
32+
"stock_prediction should keep synthetic data generation reproducible"
33+
);
34+
assert!(
35+
source.contains("generate_stock_data(80)"),
36+
"stock_prediction should keep its synthetic dataset bounded for interactive runs"
37+
);
38+
assert!(
39+
source.contains("StockPredictor::new(5, 8)"),
40+
"stock_prediction should keep sequence length and hidden size bounded for interactive runs"
41+
);
42+
assert!(
43+
source.contains(".with_config(demo_training_config())"),
44+
"stock_prediction should apply the bounded demo config before training"
45+
);
46+
}
47+
48+
#[test]
49+
fn batch_processing_example_keeps_scalability_demo_bounded() {
50+
let source = include_str!("../examples/batch_processing_example.rs");
51+
52+
for expected in [
53+
"trainer1.config.epochs = 5",
54+
"trainer2.config.epochs = 5",
55+
"trainer3.config.epochs = 5",
56+
] {
57+
assert!(
58+
source.contains(expected),
59+
"batch benchmark trainer should keep a short epoch budget: {expected}"
60+
);
61+
}
62+
63+
assert!(
64+
source.contains("trainer.config.epochs = 3"),
65+
"scalability loop should keep a short epoch budget"
66+
);
67+
}

0 commit comments

Comments
 (0)