diff --git a/README.md b/README.md index 84571c6..0ddfb97 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ Add to your `Cargo.toml`: ```toml [dependencies] -rust-lstm = "0.6" +rust-lstm = "0.8" ``` ### Basic Usage @@ -78,6 +78,7 @@ fn main() { ### Training Example ```rust +use ndarray::Array2; use rust_lstm::{LSTMNetwork, create_basic_trainer, TrainingConfig}; fn main() { @@ -94,8 +95,15 @@ fn main() { ..Default::default() }); - // Train (train_data is slice of (input_sequence, target_sequence) tuples) - // Each input_sequence and target_sequence is Vec<Array2<f64>> + // Train data is a slice of (input_sequence, target_sequence) tuples. + // Each input_sequence and target_sequence is Vec<Array2<f64>>. + let train_data = vec![( + vec![Array2::from_shape_vec((1, 1), vec![0.0]).unwrap()], + vec![Array2::from_shape_vec((10, 1), vec![0.0; 10]).unwrap()], + )]; + // Keep validation data separate from training data in real applications. + let validation_data = train_data.clone(); + trainer.train(&train_data, Some(&validation_data)); } ``` @@ -103,6 +111,7 @@ fn main() { ### Early Stopping ```rust +use ndarray::Array2; use rust_lstm::{ LSTMNetwork, create_basic_trainer, TrainingConfig, EarlyStoppingConfig, EarlyStoppingMetric @@ -128,6 +137,13 @@ fn main() { let mut trainer = create_basic_trainer(network, 0.001) .with_config(config); + let train_data = vec![( + vec![Array2::from_shape_vec((1, 1), vec![0.0]).unwrap()], + vec![Array2::from_shape_vec((10, 1), vec![0.0; 10]).unwrap()], + )]; + // Keep validation data separate from training data in real applications. 
+ let validation_data = train_data.clone(); + // Training will stop early if validation loss stops improving trainer.train(&train_data, Some(&validation_data)); } @@ -136,7 +152,16 @@ fn main() { ### Bidirectional LSTM ```rust -use rust_lstm::layers::bilstm_network::{BiLSTMNetwork, CombineMode}; +use ndarray::Array2; +use rust_lstm::layers::bilstm_network::BiLSTMNetwork; + +let input_size = 3; +let hidden_size = 5; +let num_layers = 1; +let sequence = vec![ + Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap(), + Array2::from_shape_vec((input_size, 1), vec![0.2, -0.4, 0.7]).unwrap(), +]; // BiLSTM with concatenated outputs (output_size = 2 * hidden_size) let mut bilstm = BiLSTMNetwork::new_concat(input_size, hidden_size, num_layers); @@ -170,23 +195,38 @@ graph TD ### GRU Networks ```rust +use ndarray::Array2; use rust_lstm::models::gru_network::GRUNetwork; +let input_size = 3; +let hidden_size = 5; +let num_layers = 2; + // Create GRU network (alternative to LSTM) let mut gru = GRUNetwork::new(input_size, hidden_size, num_layers) .with_input_dropout(0.2, true) .with_recurrent_dropout(0.3, true); -// Forward pass -let (output, _) = gru.forward(&input, &hidden_state); +let input = Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap(); +let hidden_states = vec![Array2::zeros((hidden_size, 1)); num_layers]; + +// Forward pass returns one hidden state per layer +let outputs = gru.forward(&input, &hidden_states); +let output = outputs.last().unwrap(); ``` ### Linear Layer ```rust +use ndarray::Array2; use rust_lstm::layers::linear::LinearLayer; use rust_lstm::optimizers::Adam; +let hidden_size = 4; +let num_classes = 3; +let lstm_output = Array2::ones((hidden_size, 1)); +let grad_output = Array2::ones((num_classes, 1)); + // Create linear layer for classification: hidden_size -> num_classes let mut classifier = LinearLayer::new(hidden_size, num_classes); let mut optimizer = Adam::new(0.001); @@ -251,7 +291,7 @@ use rust_lstm::{ let 
network = LSTMNetwork::new(1, 10, 2); // Step decay: reduce LR by 50% every 10 epochs -let mut trainer = create_step_lr_trainer(network, 0.01, 10, 0.5); +let mut trainer = create_step_lr_trainer(network.clone(), 0.01, 10, 0.5); // OneCycle policy for modern deep learning let mut trainer = create_one_cycle_trainer(network.clone(), 0.1, 100); diff --git a/tests/readme_examples_test.rs b/tests/readme_examples_test.rs index f931d0d..a01ff51 100644 --- a/tests/readme_examples_test.rs +++ b/tests/readme_examples_test.rs @@ -8,12 +8,20 @@ use ndarray::Array2; use rust_lstm::{ + layers::bilstm_network::BiLSTMNetwork, layers::dropout::{Dropout, Zoneout}, + layers::linear::LinearLayer, layers::peephole_lstm_cell::PeepholeLSTMCell, loss::{CrossEntropyLoss, MAELoss, MSELoss}, - optimizers::{Adam, RMSprop, SGD}, - training::create_basic_trainer, - LSTMNetwork, LSTMTrainer, LayerDropoutConfig, TrainingConfig, + models::gru_network::GRUNetwork, + optimizers::{Adam, RMSprop, ScheduledOptimizer, SGD}, + schedulers::{CyclicalLR, LRScheduleVisualizer, PolynomialLR, WarmupScheduler}, + training::{ + create_basic_trainer, create_cosine_annealing_trainer, create_one_cycle_trainer, + create_step_lr_trainer, + }, + EarlyStoppingConfig, EarlyStoppingMetric, LSTMNetwork, LSTMTrainer, LayerDropoutConfig, + TrainingConfig, }; #[test] @@ -119,6 +127,131 @@ fn test_training_example() { assert_eq!(predictions[0].shape(), &[4, 1]); } +#[test] +fn test_readme_early_stopping_example() { + let network = LSTMNetwork::new(1, 4, 1); + + // Configure early stopping + let early_stopping = EarlyStoppingConfig { + patience: 2, + min_delta: 1e-4, + restore_best_weights: true, + monitor: EarlyStoppingMetric::ValidationLoss, + }; + + let config = TrainingConfig { + epochs: 2, + early_stopping: Some(early_stopping), + ..Default::default() + }; + + let mut trainer = create_basic_trainer(network, 0.001).with_config(config); + let train_data = generate_test_data(); + let validation_data = generate_test_data(); + 
+ trainer.train(&train_data, Some(&validation_data)); + + assert_eq!(trainer.config.early_stopping.as_ref().unwrap().patience, 2); +} + +#[test] +fn test_readme_bilstm_example() { + let input_size = 3; + let hidden_size = 5; + let num_layers = 1; + + // BiLSTM with concatenated outputs (output_size = 2 * hidden_size) + let mut bilstm = BiLSTMNetwork::new_concat(input_size, hidden_size, num_layers); + + // Process sequence with both past and future context + let sequence = vec![ + Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap(), + Array2::from_shape_vec((input_size, 1), vec![0.2, -0.4, 0.7]).unwrap(), + ]; + let outputs = bilstm.forward_sequence(&sequence); + + assert_eq!(outputs.len(), sequence.len()); + for output in outputs { + assert_eq!(output.shape(), &[2 * hidden_size, 1]); + } +} + +#[test] +fn test_readme_gru_example() { + let input_size = 3; + let hidden_size = 5; + let num_layers = 2; + + // Create GRU network (alternative to LSTM) + let mut gru = GRUNetwork::new(input_size, hidden_size, num_layers) + .with_input_dropout(0.2, true) + .with_recurrent_dropout(0.3, true); + + let input = Array2::from_shape_vec((input_size, 1), vec![0.5, 0.1, -0.3]).unwrap(); + let hidden_states = vec![Array2::zeros((hidden_size, 1)); num_layers]; + + // Forward pass returns one hidden state per layer + let outputs = gru.forward(&input, &hidden_states); + let output = outputs.last().unwrap(); + + assert_eq!(outputs.len(), num_layers); + assert_eq!(output.shape(), &[hidden_size, 1]); +} + +#[test] +fn test_readme_linear_layer_example() { + let hidden_size = 4; + let num_classes = 3; + + // Create linear layer for classification: hidden_size -> num_classes + let mut classifier = LinearLayer::new(hidden_size, num_classes); + let mut optimizer = Adam::new(0.001); + + // Forward pass + let lstm_output = Array2::ones((hidden_size, 1)); + let logits = classifier.forward(&lstm_output); + + // Backward pass + let grad_output = Array2::ones((num_classes, 1)); + 
let (gradients, input_grad) = classifier.backward(&grad_output); + classifier.update_parameters(&gradients, &mut optimizer, "classifier"); + + assert_eq!(logits.shape(), &[num_classes, 1]); + assert_eq!(input_grad.shape(), &[hidden_size, 1]); +} + +#[test] +fn test_readme_advanced_learning_rate_scheduling_example() { + // Create a network + let network = LSTMNetwork::new(1, 4, 1); + + // Step decay: reduce LR by 50% every 10 epochs + let mut step_trainer = create_step_lr_trainer(network.clone(), 0.01, 10, 0.5); + + // OneCycle policy for modern deep learning + let mut one_cycle_trainer = create_one_cycle_trainer(network.clone(), 0.1, 100); + + // Cosine annealing with warm restarts + let mut cosine_trainer = create_cosine_annealing_trainer(network.clone(), 0.01, 20, 1e-6); + + // Advanced combinations - Warmup + Cyclical scheduling + let base_scheduler = CyclicalLR::new(0.001, 0.01, 10); + let warmup_scheduler = WarmupScheduler::new(5, base_scheduler, 0.0001); + let mut optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01); + + // Polynomial decay with visualization + let poly_scheduler = PolynomialLR::new(100, 2.0, 0.001); + let schedule = LRScheduleVisualizer::generate_schedule(poly_scheduler, 0.01, 100); + + step_trainer.optimizer.step(); + one_cycle_trainer.optimizer.step(); + cosine_trainer.optimizer.step(); + optimizer.step(); + + assert_eq!(schedule.len(), 100); + assert!(optimizer.get_current_lr() > 0.0); +} + #[test] fn test_dropout_types_example() { // Standard dropout