|
8 | 8 |
|
9 | 9 | use ndarray::Array2; |
10 | 10 | use rust_lstm::{ |
| 11 | + layers::bilstm_network::BiLSTMNetwork, |
11 | 12 | layers::dropout::{Dropout, Zoneout}, |
| 13 | + layers::linear::LinearLayer, |
12 | 14 | layers::peephole_lstm_cell::PeepholeLSTMCell, |
13 | 15 | loss::{CrossEntropyLoss, MAELoss, MSELoss}, |
14 | | - optimizers::{Adam, RMSprop, SGD}, |
15 | | - training::create_basic_trainer, |
16 | | - LSTMNetwork, LSTMTrainer, LayerDropoutConfig, TrainingConfig, |
| 16 | + models::gru_network::GRUNetwork, |
| 17 | + optimizers::{Adam, RMSprop, ScheduledOptimizer, SGD}, |
| 18 | + schedulers::{CyclicalLR, LRScheduleVisualizer, PolynomialLR, WarmupScheduler}, |
| 19 | + training::{ |
| 20 | + create_basic_trainer, create_cosine_annealing_trainer, create_one_cycle_trainer, |
| 21 | + create_step_lr_trainer, |
| 22 | + }, |
| 23 | + EarlyStoppingConfig, EarlyStoppingMetric, LSTMNetwork, LSTMTrainer, LayerDropoutConfig, |
| 24 | + TrainingConfig, |
17 | 25 | }; |
18 | 26 |
|
19 | 27 | #[test] |
@@ -119,6 +127,131 @@ fn test_training_example() { |
119 | 127 | assert_eq!(predictions[0].shape(), &[4, 1]); |
120 | 128 | } |
121 | 129 |
|
| 130 | +#[test] |
| 131 | +fn test_readme_early_stopping_example() { |
| 132 | + let network = LSTMNetwork::new(1, 4, 1); |
| 133 | + |
| 134 | + // Configure early stopping |
| 135 | + let early_stopping = EarlyStoppingConfig { |
| 136 | + patience: 2, |
| 137 | + min_delta: 1e-4, |
| 138 | + restore_best_weights: true, |
| 139 | + monitor: EarlyStoppingMetric::ValidationLoss, |
| 140 | + }; |
| 141 | + |
| 142 | + let config = TrainingConfig { |
| 143 | + epochs: 2, |
| 144 | + early_stopping: Some(early_stopping), |
| 145 | + ..Default::default() |
| 146 | + }; |
| 147 | + |
| 148 | + let mut trainer = create_basic_trainer(network, 0.001).with_config(config); |
| 149 | + let train_data = generate_test_data(); |
| 150 | + let validation_data = generate_test_data(); |
| 151 | + |
| 152 | + trainer.train(&train_data, Some(&validation_data)); |
| 153 | + |
| 154 | + assert_eq!(trainer.config.early_stopping.as_ref().unwrap().patience, 2); |
| 155 | +} |
| 156 | + |
| 157 | +#[test] |
| 158 | +fn test_readme_bilstm_example() { |
| 159 | + let input_size = 3; |
| 160 | + let hidden_size = 5; |
| 161 | + let num_layers = 1; |
| 162 | + |
| 163 | + // BiLSTM with concatenated outputs (output_size = 2 * hidden_size) |
| 164 | + let mut bilstm = BiLSTMNetwork::new_concat(input_size, hidden_size, num_layers); |
| 165 | + |
| 166 | + // Process sequence with both past and future context |
| 167 | + let sequence = vec![ |
| 168 | + Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap(), |
| 169 | + Array2::from_shape_vec((input_size, 1), vec![0.2, -0.4, 0.7]).unwrap(), |
| 170 | + ]; |
| 171 | + let outputs = bilstm.forward_sequence(&sequence); |
| 172 | + |
| 173 | + assert_eq!(outputs.len(), sequence.len()); |
| 174 | + for output in outputs { |
| 175 | + assert_eq!(output.shape(), &[2 * hidden_size, 1]); |
| 176 | + } |
| 177 | +} |
| 178 | + |
| 179 | +#[test] |
| 180 | +fn test_readme_gru_example() { |
| 181 | + let input_size = 3; |
| 182 | + let hidden_size = 5; |
| 183 | + let num_layers = 2; |
| 184 | + |
| 185 | + // Create GRU network (alternative to LSTM) |
| 186 | + let mut gru = GRUNetwork::new(input_size, hidden_size, num_layers) |
| 187 | + .with_input_dropout(0.2, true) |
| 188 | + .with_recurrent_dropout(0.3, true); |
| 189 | + |
| 190 | + let input = Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap(); |
| 191 | + let hidden_states = vec![Array2::zeros((hidden_size, 1)); num_layers]; |
| 192 | + |
| 193 | + // Forward pass returns one hidden state per layer |
| 194 | + let outputs = gru.forward(&input, &hidden_states); |
| 195 | + let output = outputs.last().unwrap(); |
| 196 | + |
| 197 | + assert_eq!(outputs.len(), num_layers); |
| 198 | + assert_eq!(output.shape(), &[hidden_size, 1]); |
| 199 | +} |
| 200 | + |
| 201 | +#[test] |
| 202 | +fn test_readme_linear_layer_example() { |
| 203 | + let hidden_size = 4; |
| 204 | + let num_classes = 3; |
| 205 | + |
| 206 | + // Create linear layer for classification: hidden_size -> num_classes |
| 207 | + let mut classifier = LinearLayer::new(hidden_size, num_classes); |
| 208 | + let mut optimizer = Adam::new(0.001); |
| 209 | + |
| 210 | + // Forward pass |
| 211 | + let lstm_output = Array2::ones((hidden_size, 1)); |
| 212 | + let logits = classifier.forward(&lstm_output); |
| 213 | + |
| 214 | + // Backward pass |
| 215 | + let grad_output = Array2::ones((num_classes, 1)); |
| 216 | + let (gradients, input_grad) = classifier.backward(&grad_output); |
| 217 | + classifier.update_parameters(&gradients, &mut optimizer, "classifier"); |
| 218 | + |
| 219 | + assert_eq!(logits.shape(), &[num_classes, 1]); |
| 220 | + assert_eq!(input_grad.shape(), &[hidden_size, 1]); |
| 221 | +} |
| 222 | + |
| 223 | +#[test] |
| 224 | +fn test_readme_advanced_learning_rate_scheduling_example() { |
| 225 | + // Create a network |
| 226 | + let network = LSTMNetwork::new(1, 4, 1); |
| 227 | + |
| 228 | + // Step decay: reduce LR by 50% every 10 epochs |
| 229 | + let mut step_trainer = create_step_lr_trainer(network.clone(), 0.01, 10, 0.5); |
| 230 | + |
| 231 | + // OneCycle policy for modern deep learning |
| 232 | + let mut one_cycle_trainer = create_one_cycle_trainer(network.clone(), 0.1, 100); |
| 233 | + |
| 234 | + // Cosine annealing with warm restarts |
| 235 | + let mut cosine_trainer = create_cosine_annealing_trainer(network.clone(), 0.01, 20, 1e-6); |
| 236 | + |
| 237 | + // Advanced combinations - Warmup + Cyclical scheduling |
| 238 | + let base_scheduler = CyclicalLR::new(0.001, 0.01, 10); |
| 239 | + let warmup_scheduler = WarmupScheduler::new(5, base_scheduler, 0.0001); |
| 240 | + let mut optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01); |
| 241 | + |
| 242 | + // Polynomial decay with visualization |
| 243 | + let poly_scheduler = PolynomialLR::new(100, 2.0, 0.001); |
| 244 | + let schedule = LRScheduleVisualizer::generate_schedule(poly_scheduler, 0.01, 100); |
| 245 | + |
| 246 | + step_trainer.optimizer.step(); |
| 247 | + one_cycle_trainer.optimizer.step(); |
| 248 | + cosine_trainer.optimizer.step(); |
| 249 | + optimizer.step(); |
| 250 | + |
| 251 | + assert_eq!(schedule.len(), 100); |
| 252 | + assert!(optimizer.get_current_lr() > 0.0); |
| 253 | +} |
| 254 | + |
122 | 255 | #[test] |
123 | 256 | fn test_dropout_types_example() { |
124 | 257 | // Standard dropout |
|
0 commit comments