Skip to content

Commit 2895c4a

Browse files
authored
test: bound example training runs (#21)
1 parent 941f61a commit 2895c4a

3 files changed

Lines changed: 109 additions & 68 deletions

File tree

examples/stock_prediction.rs

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,26 @@ use rust_lstm::models::lstm_network::LSTMNetwork;
1313
use rust_lstm::optimizers::Adam;
1414
use rust_lstm::training::{LSTMTrainer, TrainingConfig};
1515

16-
fn demo_training_config() -> TrainingConfig {
16+
pub(crate) const DEMO_EPOCHS: usize = 2;
17+
pub(crate) const DEMO_PRINT_EVERY: usize = 1;
18+
pub(crate) const DEMO_STOCK_DAYS: usize = 80;
19+
pub(crate) const DEMO_SEQUENCE_LENGTH: usize = 5;
20+
pub(crate) const DEMO_HIDDEN_SIZE: usize = 8;
21+
pub(crate) const DEMO_RNG_SEED: u64 = 42;
22+
23+
pub(crate) fn demo_training_config() -> TrainingConfig {
1724
TrainingConfig {
18-
epochs: 2,
19-
print_every: 1,
25+
epochs: DEMO_EPOCHS,
26+
print_every: DEMO_PRINT_EVERY,
2027
log_lr_changes: false,
2128
..TrainingConfig::default()
2229
}
2330
}
2431

32+
pub(crate) fn demo_stock_trainer(network: LSTMNetwork) -> LSTMTrainer<MSELoss, Adam> {
33+
LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(demo_training_config())
34+
}
35+
2536
/// Stock data point with OHLCV (Open, High, Low, Close, Volume)
2637
#[derive(Debug, Clone)]
2738
#[allow(dead_code)]
@@ -147,11 +158,7 @@ impl StockPredictor {
147158
val_data.len()
148159
);
149160

150-
// Create trainer with Adam optimizer
151-
let loss_function = MSELoss;
152-
let optimizer = Adam::new(0.001);
153-
let mut trainer = LSTMTrainer::new(self.network.clone(), loss_function, optimizer)
154-
.with_config(demo_training_config());
161+
let mut trainer = demo_stock_trainer(self.network.clone());
155162

156163
// Train the model
157164
trainer.train(train_data, Some(val_data));
@@ -189,7 +196,7 @@ impl StockPredictor {
189196

190197
/// Generate synthetic stock data for demonstration
191198
fn generate_stock_data(days: usize) -> Vec<StockData> {
192-
let mut rng = StdRng::seed_from_u64(42);
199+
let mut rng = StdRng::seed_from_u64(DEMO_RNG_SEED);
193200
let mut data = Vec::new();
194201
let mut price = 100.0;
195202
let volume_base = 1_000_000.0;
@@ -226,13 +233,21 @@ fn generate_stock_data(days: usize) -> Vec<StockData> {
226233
data
227234
}
228235

236+
#[allow(dead_code)]
237+
pub(crate) fn demo_stock_closes(days: usize) -> Vec<f64> {
238+
generate_stock_data(days)
239+
.into_iter()
240+
.map(|stock| stock.close)
241+
.collect()
242+
}
243+
229244
fn main() {
230245
println!("🏦 Stock Price Prediction with LSTM");
231246
println!("=====================================\n");
232247
println!("This bounded demo favors quick execution over prediction quality.\n");
233248

234249
// Generate synthetic stock data (in practice, you'd load real data)
235-
let stock_data = generate_stock_data(80); // Bounded synthetic dataset for a quick demo
250+
let stock_data = generate_stock_data(DEMO_STOCK_DAYS); // Bounded synthetic dataset for a quick demo
236251
println!(
237252
"📈 Generated {} days of synthetic stock data",
238253
stock_data.len()
@@ -250,16 +265,16 @@ fn main() {
250265
}
251266

252267
// Create and train predictor
253-
let mut predictor = StockPredictor::new(5, 8); // 5-day sequences, 8 hidden units
268+
let mut predictor = StockPredictor::new(DEMO_SEQUENCE_LENGTH, DEMO_HIDDEN_SIZE);
254269
predictor.train(&stock_data, 0.2); // 80% train, 20% validation
255270

256271
// Make predictions on recent data
257272
println!("\n🔮 Making predictions...");
258-
let recent_data = &stock_data[stock_data.len() - 10..]; // Last 10 days
273+
let recent_window = DEMO_SEQUENCE_LENGTH * 2;
274+
let recent_data = &stock_data[stock_data.len() - recent_window..];
259275

260-
for i in 5..10 {
261-
// Predict for days 6-10 of recent data
262-
let input_data = &recent_data[i - 5..i];
276+
for i in DEMO_SEQUENCE_LENGTH..recent_window {
277+
let input_data = &recent_data[i - DEMO_SEQUENCE_LENGTH..i];
263278
if let Some(predicted_price) = predictor.predict_next_price(input_data) {
264279
let actual_price = recent_data[i].close;
265280
let error = (predicted_price - actual_price).abs();

examples/training_example.rs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,44 @@ use rust_lstm::models::lstm_network::LSTMNetwork;
1212
use rust_lstm::optimizers::{Adam, SGD};
1313
use rust_lstm::training::{LSTMTrainer, TrainingConfig};
1414

15-
fn demo_training_config() -> TrainingConfig {
15+
pub(crate) const DEMO_EPOCHS: usize = 5;
16+
pub(crate) const DEMO_PRINT_EVERY: usize = 1;
17+
18+
pub(crate) fn demo_training_config() -> TrainingConfig {
1619
TrainingConfig {
17-
epochs: 5,
18-
print_every: 1,
20+
epochs: DEMO_EPOCHS,
21+
print_every: DEMO_PRINT_EVERY,
1922
log_lr_changes: false,
2023
..TrainingConfig::default()
2124
}
2225
}
2326

27+
pub(crate) fn demo_sgd_trainer(
28+
input_size: usize,
29+
hidden_size: usize,
30+
num_layers: usize,
31+
) -> LSTMTrainer<MSELoss, SGD> {
32+
LSTMTrainer::new(
33+
LSTMNetwork::new(input_size, hidden_size, num_layers),
34+
MSELoss,
35+
SGD::new(0.01),
36+
)
37+
.with_config(demo_training_config())
38+
}
39+
40+
pub(crate) fn demo_adam_trainer(
41+
input_size: usize,
42+
hidden_size: usize,
43+
num_layers: usize,
44+
) -> LSTMTrainer<MSELoss, Adam> {
45+
LSTMTrainer::new(
46+
LSTMNetwork::new(input_size, hidden_size, num_layers),
47+
MSELoss,
48+
Adam::new(0.001),
49+
)
50+
.with_config(demo_training_config())
51+
}
52+
2453
/// Generate sine wave training data for sequence prediction
2554
fn generate_sine_data(
2655
num_sequences: usize,
@@ -95,9 +124,7 @@ fn main() {
95124

96125
// Training with SGD
97126
println!("Training with SGD optimizer:");
98-
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
99-
let mut trainer_sgd =
100-
LSTMTrainer::new(network, MSELoss, SGD::new(0.01)).with_config(demo_training_config());
127+
let mut trainer_sgd = demo_sgd_trainer(input_size, hidden_size, num_layers);
101128

102129
trainer_sgd.train(&train_data, Some(&val_data));
103130

@@ -118,9 +145,7 @@ fn main() {
118145

119146
// Training with Adam
120147
println!("Training with Adam optimizer:");
121-
let network = LSTMNetwork::new(input_size, hidden_size, num_layers);
122-
let mut trainer_adam =
123-
LSTMTrainer::new(network, MSELoss, Adam::new(0.001)).with_config(demo_training_config());
148+
let mut trainer_adam = demo_adam_trainer(input_size, hidden_size, num_layers);
124149

125150
trainer_adam.train(&train_data, Some(&val_data));
126151

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,68 @@
1+
#[allow(dead_code)]
2+
#[path = "../examples/training_example.rs"]
3+
mod training_example;
4+
5+
#[allow(dead_code)]
6+
#[path = "../examples/stock_prediction.rs"]
7+
mod stock_prediction;
8+
9+
use std::hint::black_box;
10+
111
#[test]
2-
fn training_example_uses_bounded_demo_epochs() {
3-
let source = include_str!("../examples/training_example.rs");
12+
fn training_example_applies_bounded_config_to_both_demo_trainers() {
13+
let sgd_trainer = training_example::demo_sgd_trainer(1, 10, 1);
14+
let adam_trainer = training_example::demo_adam_trainer(1, 10, 1);
415

5-
assert!(
6-
source.contains("epochs: 5"),
7-
"training_example should keep interactive demo training bounded"
8-
);
9-
assert!(
10-
source.contains("fn demo_training_config() -> TrainingConfig"),
11-
"training_example should centralize its bounded demo training config"
12-
);
13-
assert!(
14-
source
15-
.matches(".with_config(demo_training_config())")
16-
.count()
17-
>= 2,
18-
"both optimizer demonstrations should apply the bounded demo config"
19-
);
16+
for epochs in [sgd_trainer.config.epochs, adam_trainer.config.epochs] {
17+
assert!(
18+
black_box(epochs) <= 5,
19+
"training_example should keep both optimizer demos bounded"
20+
);
21+
}
22+
23+
for print_every in [
24+
sgd_trainer.config.print_every,
25+
adam_trainer.config.print_every,
26+
] {
27+
assert!(
28+
black_box(print_every) <= black_box(training_example::DEMO_EPOCHS),
29+
"training_example should report progress without exceeding the epoch budget"
30+
);
31+
}
2032
}
2133

2234
#[test]
23-
fn stock_prediction_example_uses_bounded_demo_epochs() {
24-
let source = include_str!("../examples/stock_prediction.rs");
35+
fn stock_prediction_applies_bounded_config_to_demo_trainer() {
36+
let network =
37+
rust_lstm::models::lstm_network::LSTMNetwork::new(5, stock_prediction::DEMO_HIDDEN_SIZE, 2);
38+
let trainer = stock_prediction::demo_stock_trainer(network);
2539

2640
assert!(
27-
source.contains("epochs: 2"),
41+
black_box(trainer.config.epochs) <= 2,
2842
"stock_prediction should avoid the default 100-epoch training config"
2943
);
3044
assert!(
31-
source.contains("StdRng::seed_from_u64(42)"),
32-
"stock_prediction should keep synthetic data generation reproducible"
33-
);
34-
assert!(
35-
source.contains("generate_stock_data(80)"),
45+
black_box(stock_prediction::DEMO_STOCK_DAYS) <= 80,
3646
"stock_prediction should keep its synthetic dataset bounded for interactive runs"
3747
);
3848
assert!(
39-
source.contains("StockPredictor::new(5, 8)"),
40-
"stock_prediction should keep sequence length and hidden size bounded for interactive runs"
49+
black_box(stock_prediction::DEMO_SEQUENCE_LENGTH) <= 5,
50+
"stock_prediction should keep sequence length bounded for interactive runs"
4151
);
4252
assert!(
43-
source.contains(".with_config(demo_training_config())"),
44-
"stock_prediction should apply the bounded demo config before training"
53+
black_box(stock_prediction::DEMO_HIDDEN_SIZE) <= 8,
54+
"stock_prediction should keep hidden size bounded for interactive runs"
4555
);
4656
}
4757

4858
#[test]
49-
fn batch_processing_example_keeps_scalability_demo_bounded() {
50-
let source = include_str!("../examples/batch_processing_example.rs");
51-
52-
for expected in [
53-
"trainer1.config.epochs = 5",
54-
"trainer2.config.epochs = 5",
55-
"trainer3.config.epochs = 5",
56-
] {
57-
assert!(
58-
source.contains(expected),
59-
"batch benchmark trainer should keep a short epoch budget: {expected}"
60-
);
61-
}
59+
fn stock_prediction_demo_data_is_reproducible() {
60+
let first = stock_prediction::demo_stock_closes(black_box(8));
61+
let second = stock_prediction::demo_stock_closes(black_box(8));
6262

63+
assert_eq!(first, second, "seeded demo data should be reproducible");
6364
assert!(
64-
source.contains("trainer.config.epochs = 3"),
65-
"scalability loop should keep a short epoch budget"
65+
first.windows(2).any(|window| window[0] != window[1]),
66+
"demo data should still vary across generated days"
6667
);
6768
}

0 commit comments

Comments
 (0)