Skip to content

Commit 3845fba

Browse files
authored
test: bound advanced LR scheduling example (#24)
1 parent 4f981db commit 3845fba

2 files changed

Lines changed: 205 additions & 88 deletions

File tree

examples/advanced_lr_scheduling.rs

Lines changed: 109 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,29 @@ use rust_lstm::{
1212
ScheduledLSTMTrainer, ScheduledOptimizer, StepLR, TrainingConfig, WarmupScheduler,
1313
};
1414

15+
// ---------------------------------------------------------------------------
// Demo sizing knobs. Everything below is deliberately small; the values are
// public so they can be inspected from outside this file.
// ---------------------------------------------------------------------------

// Synthetic data set.

/// Number of sequences requested from `generate_sine_wave_data` for the training set.
pub const DEMO_TRAIN_SEQUENCES: usize = 12;
/// Number of sequences requested from `generate_sine_wave_data` for the validation set.
pub const DEMO_VAL_SEQUENCES: usize = 4;
/// Number of time steps produced for every generated sequence.
pub const DEMO_SEQUENCE_LENGTH: usize = 6;

// Network sizes (second argument of `LSTMNetwork::new`).

/// Size used by the polynomial, cyclical and warmup examples.
pub const DEMO_HIDDEN_SIZE: usize = 4;
/// Size used by the advanced (dropout) training example.
pub const DEMO_ADVANCED_HIDDEN_SIZE: usize = 6;

// Training length, i.e. the `epochs` field of each example's `TrainingConfig`.

/// Epochs for the polynomial decay example.
pub const DEMO_POLYNOMIAL_EPOCHS: usize = 5;
/// Epochs shared by the three cyclical LR examples.
pub const DEMO_CYCLICAL_EPOCHS: usize = 5;
/// Epochs for the warmup scheduler example (total run length, not the warmup phase).
pub const DEMO_WARMUP_EPOCHS: usize = 5;
/// Epochs for the advanced training example.
pub const DEMO_ADVANCED_EPOCHS: usize = 5;

// Scheduler parameters.

/// `total_iters` handed to `ScheduledOptimizer::polynomial`; tied to the epoch count.
pub const DEMO_POLYNOMIAL_ITERS: usize = DEMO_POLYNOMIAL_EPOCHS;
/// `step_size` handed to every cyclical scheduler constructed in this file.
pub const DEMO_CYCLICAL_STEP_SIZE: usize = 2;
/// First argument of `WarmupScheduler::new` (length of the warmup phase itself).
pub const DEMO_WARMUP_EPOCH_COUNT: usize = 2;
/// Step size of the `StepLR` wrapped by the warmup scheduler example.
pub const DEMO_BASE_STEP_SIZE: usize = 2;

// ASCII schedule visualization.

/// Step size of the `StepLR` drawn by the schedule visualization.
pub const DEMO_VISUALIZATION_STEP_SIZE: usize = 2;
/// Number of schedule steps passed to `LRScheduleVisualizer::print_schedule`.
pub const DEMO_VISUALIZATION_STEPS: usize = 20;
30+
1531
fn main() {
1632
println!("🚀 Advanced Learning Rate Scheduling for Rust-LSTM");
1733
println!("===================================================\n");
1834

1935
// Generate sample training data
20-
let train_data = generate_sine_wave_data(50, 0.0);
21-
let val_data = generate_sine_wave_data(10, 1000.0);
36+
let train_data = generate_sine_wave_data(DEMO_TRAIN_SEQUENCES, 0.0);
37+
let val_data = generate_sine_wave_data(DEMO_VAL_SEQUENCES, 1000.0);
2238

2339
// 1. Polynomial Decay Example
2440
polynomial_decay_example(&train_data, &val_data);
@@ -43,24 +59,18 @@ fn polynomial_decay_example(
4359
println!("1️⃣ Polynomial Decay Example");
4460
println!(" Smoothly decays LR using polynomial function\n");
4561

46-
let network = LSTMNetwork::new(1, 8, 1);
62+
let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 1);
4763

4864
let loss_function = MSELoss;
4965
let scheduled_optimizer = ScheduledOptimizer::polynomial(
5066
Adam::new(0.01),
51-
0.01, // base_lr
52-
25, // total_iters
53-
2.0, // power
54-
0.001, // end_lr
67+
0.01, // base_lr
68+
DEMO_POLYNOMIAL_ITERS, // total_iters
69+
2.0, // power
70+
0.001, // end_lr
5571
);
5672

57-
let config = TrainingConfig {
58-
epochs: 30,
59-
print_every: 5,
60-
clip_gradient: Some(1.0),
61-
log_lr_changes: true,
62-
early_stopping: None,
63-
};
73+
let config = polynomial_decay_training_config();
6474

6575
let mut trainer =
6676
ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config);
@@ -80,23 +90,17 @@ fn cyclical_lr_examples(
8090

8191
// 2a. Triangular Cyclical LR
8292
println!("2a. Triangular Cyclical LR");
83-
let network = LSTMNetwork::new(1, 8, 1);
93+
let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 1);
8494

8595
let loss_function = MSELoss;
8696
let scheduled_optimizer = ScheduledOptimizer::cyclical(
8797
Adam::new(0.001),
88-
0.001, // base_lr
89-
0.01, // max_lr
90-
8, // step_size
98+
0.001, // base_lr
99+
0.01, // max_lr
100+
DEMO_CYCLICAL_STEP_SIZE, // step_size
91101
);
92102

93-
let config = TrainingConfig {
94-
epochs: 25,
95-
print_every: 5,
96-
clip_gradient: Some(1.0),
97-
log_lr_changes: false, // Too frequent for cyclical
98-
early_stopping: None,
99-
};
103+
let config = cyclical_lr_training_config();
100104

101105
let mut trainer =
102106
ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config);
@@ -106,23 +110,17 @@ fn cyclical_lr_examples(
106110

107111
// 2b. Triangular2 Cyclical LR (halving amplitude each cycle)
108112
println!("2b. Triangular2 Cyclical LR (halving amplitude each cycle)");
109-
let network = LSTMNetwork::new(1, 8, 1);
113+
let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 1);
110114

111115
let loss_function = MSELoss;
112116
let scheduled_optimizer = ScheduledOptimizer::cyclical_triangular2(
113117
Adam::new(0.001),
114-
0.001, // base_lr
115-
0.01, // max_lr
116-
8, // step_size
118+
0.001, // base_lr
119+
0.01, // max_lr
120+
DEMO_CYCLICAL_STEP_SIZE, // step_size
117121
);
118122

119-
let config2 = TrainingConfig {
120-
epochs: 25,
121-
print_every: 5,
122-
clip_gradient: Some(1.0),
123-
log_lr_changes: false,
124-
early_stopping: None,
125-
};
123+
let config2 = cyclical_lr_training_config();
126124

127125
let mut trainer =
128126
ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config2);
@@ -132,24 +130,18 @@ fn cyclical_lr_examples(
132130

133131
// 2c. ExpRange Cyclical LR (exponential scaling)
134132
println!("2c. ExpRange Cyclical LR (exponential scaling)");
135-
let network = LSTMNetwork::new(1, 8, 1);
133+
let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 1);
136134

137135
let loss_function = MSELoss;
138136
let scheduled_optimizer = ScheduledOptimizer::cyclical_exp_range(
139137
Adam::new(0.001),
140-
0.001, // base_lr
141-
0.01, // max_lr
142-
8, // step_size
143-
0.95, // gamma
138+
0.001, // base_lr
139+
0.01, // max_lr
140+
DEMO_CYCLICAL_STEP_SIZE, // step_size
141+
0.95, // gamma
144142
);
145143

146-
let config3 = TrainingConfig {
147-
epochs: 25,
148-
print_every: 5,
149-
clip_gradient: Some(1.0),
150-
log_lr_changes: false,
151-
early_stopping: None,
152-
};
144+
let config3 = cyclical_lr_training_config();
153145

154146
let mut trainer =
155147
ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config3);
@@ -167,26 +159,15 @@ fn warmup_scheduler_example(
167159
println!("3️⃣ Warmup Scheduler Example");
168160
println!(" Gradually increases LR during warmup, then applies base scheduler\n");
169161

170-
let network = LSTMNetwork::new(1, 8, 1);
162+
let network = LSTMNetwork::new(1, DEMO_HIDDEN_SIZE, 1);
171163

172-
// Create warmup scheduler with step decay after warmup
173-
let base_scheduler = StepLR::new(10, 0.5); // Reduce by half every 10 epochs
174-
let warmup_scheduler = WarmupScheduler::new(
175-
5, // warmup_epochs
176-
base_scheduler, // base_scheduler
177-
0.001, // warmup_start_lr
178-
);
164+
let base_scheduler = StepLR::new(DEMO_BASE_STEP_SIZE, 0.5);
165+
let warmup_scheduler = WarmupScheduler::new(DEMO_WARMUP_EPOCH_COUNT, base_scheduler, 0.001);
179166

180167
let loss_function = MSELoss;
181168
let scheduled_optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01);
182169

183-
let config = TrainingConfig {
184-
epochs: 30,
185-
print_every: 3,
186-
clip_gradient: Some(1.0),
187-
log_lr_changes: true,
188-
early_stopping: None,
189-
};
170+
let config = warmup_scheduler_training_config();
190171

191172
let mut trainer =
192173
ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config);
@@ -202,21 +183,27 @@ fn schedule_visualization() {
202183
println!(" ASCII visualization of different schedulers\n");
203184

204185
// Visualize StepLR
205-
println!("StepLR (step_size=10, gamma=0.5):");
206-
let step_scheduler = StepLR::new(10, 0.5);
207-
LRScheduleVisualizer::print_schedule(step_scheduler, 0.01, 50, 60, 10);
186+
println!("StepLR (step_size=2, gamma=0.5):");
187+
let step_scheduler = StepLR::new(DEMO_VISUALIZATION_STEP_SIZE, 0.5);
188+
LRScheduleVisualizer::print_schedule(step_scheduler, 0.01, DEMO_VISUALIZATION_STEPS, 40, 5);
208189
println!();
209190

210191
// Visualize PolynomialLR
211192
println!("PolynomialLR (power=2.0, end_lr=0.001):");
212-
let poly_scheduler = PolynomialLR::new(50, 2.0, 0.001);
213-
LRScheduleVisualizer::print_schedule(poly_scheduler, 0.01, 50, 60, 10);
193+
let poly_scheduler = PolynomialLR::new(DEMO_VISUALIZATION_STEPS, 2.0, 0.001);
194+
LRScheduleVisualizer::print_schedule(poly_scheduler, 0.01, DEMO_VISUALIZATION_STEPS, 40, 5);
214195
println!();
215196

216197
// Visualize CyclicalLR
217-
println!("CyclicalLR Triangular (base_lr=0.001, max_lr=0.01, step_size=8):");
218-
let cyclical_scheduler = CyclicalLR::new(0.001, 0.01, 8);
219-
LRScheduleVisualizer::print_schedule(cyclical_scheduler, 0.001, 50, 60, 10);
198+
println!("CyclicalLR Triangular (base_lr=0.001, max_lr=0.01, step_size=2):");
199+
let cyclical_scheduler = CyclicalLR::new(0.001, 0.01, DEMO_CYCLICAL_STEP_SIZE);
200+
LRScheduleVisualizer::print_schedule(
201+
cyclical_scheduler,
202+
0.001,
203+
DEMO_VISUALIZATION_STEPS,
204+
40,
205+
5,
206+
);
220207
println!();
221208

222209
println!("----------------------------------------\n");
@@ -230,25 +217,20 @@ fn advanced_training_example(
230217
println!(" Warmup + Cyclical LR + Dropout + Gradient Clipping\n");
231218

232219
// Create network with dropout
233-
let network = LSTMNetwork::new(1, 16, 1)
220+
let network = LSTMNetwork::new(1, DEMO_ADVANCED_HIDDEN_SIZE, 1)
234221
.with_input_dropout(0.1, true) // Variational dropout
235222
.with_recurrent_dropout(0.2, true) // Variational recurrent dropout
236223
.with_output_dropout(0.1); // Standard output dropout
237224

238225
// Create warmup scheduler with cyclical base scheduler
239-
let base_scheduler = CyclicalLR::new(0.001, 0.01, 10).with_mode(CyclicalMode::Triangular2);
240-
let warmup_scheduler = WarmupScheduler::new(5, base_scheduler, 0.0001);
226+
let base_scheduler =
227+
CyclicalLR::new(0.001, 0.01, DEMO_CYCLICAL_STEP_SIZE).with_mode(CyclicalMode::Triangular2);
228+
let warmup_scheduler = WarmupScheduler::new(DEMO_WARMUP_EPOCH_COUNT, base_scheduler, 0.0001);
241229

242230
let loss_function = MSELoss;
243231
let scheduled_optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01);
244232

245-
let config = TrainingConfig {
246-
epochs: 40,
247-
print_every: 5,
248-
clip_gradient: Some(1.0), // Gradient clipping
249-
log_lr_changes: false, // Too frequent for cyclical
250-
early_stopping: None,
251-
};
233+
let config = advanced_training_config();
252234

253235
let mut trainer =
254236
ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer).with_config(config);
@@ -272,18 +254,17 @@ fn advanced_training_example(
272254
println!("\n✅ Advanced training complete!");
273255
}
274256

275-
fn generate_sine_wave_data(
257+
pub fn generate_sine_wave_data(
276258
num_sequences: usize,
277259
offset: f64,
278260
) -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> {
279261
let mut data = Vec::new();
280262

281263
for i in 0..num_sequences {
282-
let sequence_length = 8;
283264
let mut inputs = Vec::new();
284265
let mut targets = Vec::new();
285266

286-
for t in 0..sequence_length {
267+
for t in 0..DEMO_SEQUENCE_LENGTH {
287268
let x = (offset + i as f64 * 0.1 + t as f64 * 0.2).sin();
288269
let y = (offset + i as f64 * 0.1 + (t + 1) as f64 * 0.2).sin();
289270

@@ -297,19 +278,59 @@ fn generate_sine_wave_data(
297278
data
298279
}
299280

281+
pub fn polynomial_decay_training_config() -> TrainingConfig {
282+
TrainingConfig {
283+
epochs: DEMO_POLYNOMIAL_EPOCHS,
284+
print_every: 1,
285+
clip_gradient: Some(1.0),
286+
log_lr_changes: true,
287+
early_stopping: None,
288+
}
289+
}
290+
291+
pub fn cyclical_lr_training_config() -> TrainingConfig {
292+
TrainingConfig {
293+
epochs: DEMO_CYCLICAL_EPOCHS,
294+
print_every: 1,
295+
clip_gradient: Some(1.0),
296+
log_lr_changes: false,
297+
early_stopping: None,
298+
}
299+
}
300+
301+
pub fn warmup_scheduler_training_config() -> TrainingConfig {
302+
TrainingConfig {
303+
epochs: DEMO_WARMUP_EPOCHS,
304+
print_every: 1,
305+
clip_gradient: Some(1.0),
306+
log_lr_changes: true,
307+
early_stopping: None,
308+
}
309+
}
310+
311+
pub fn advanced_training_config() -> TrainingConfig {
312+
TrainingConfig {
313+
epochs: DEMO_ADVANCED_EPOCHS,
314+
print_every: 1,
315+
clip_gradient: Some(1.0),
316+
log_lr_changes: false,
317+
early_stopping: None,
318+
}
319+
}
320+
300321
#[cfg(test)]
301322
mod tests {
302323
use super::*;
303-
use rust_lstm::SGD;
304324

305325
#[test]
306326
fn test_advanced_schedulers() {
307327
// Test polynomial scheduler
308328
let poly_scheduler = PolynomialLR::new(100, 2.0, 0.01);
309-
let schedule = LRScheduleVisualizer::generate_schedule(poly_scheduler, 0.1, 100);
310-
assert_eq!(schedule.len(), 100);
329+
let schedule = LRScheduleVisualizer::generate_schedule(poly_scheduler, 0.1, 101);
330+
assert_eq!(schedule.len(), 101);
311331
assert_eq!(schedule[0].1, 0.1);
312-
assert!((schedule[99].1 - 0.01).abs() < 1e-10);
332+
assert!(schedule[99].1 < schedule[0].1);
333+
assert!((schedule[100].1 - 0.01).abs() < 1e-10);
313334

314335
// Test cyclical scheduler
315336
let cyclical_scheduler = CyclicalLR::new(0.01, 0.1, 10);

0 commit comments

Comments
 (0)